#!/usr/bin/env python
#
# Copyright (c) STMicroelectronics 2012
#
# This file is part of ptimeout.
#
# ptimeout is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License v2.0
# as published by the Free Software Foundation
#
# ptimeout is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# v2.0 along with ptimeout. If not, see <http://www.gnu.org/licenses/>.
#
#
# Usage: get usage with ptimeout -h
#

from __future__ import print_function
import sys

# Fail early if python version is not supported
def check_python_version():
    """
    Exit with status 1 and a message on stderr when the running
    interpreter is older than python 2.6.
    Returns None when the version is supported.
    """
    # An explicit comparison is used instead of assert, which is
    # stripped when running with -O.
    if sys.hexversion < 0x02060000:  # pragma: no cover
        # print_function is active in this file, thus the python 2
        # 'print >>stream' statement form can't be used here; write
        # directly to the stream, which works on any python version.
        sys.stderr.write(
            'ptimeout: error: python version >= 2.6 is required\n')
        sys.exit(1)
check_python_version()

import os, subprocess, signal, errno, optparse, time, hashlib, logging

# Update VERSION for major.minor.patch releases.
# The sha1 of this file is appended to the version string when it is
# printed (see LocalOptionParser.handle_version).
VERSION="1.4.0"

# Name of the environment variable carrying the timeout groups of a
# process: a ':' separated list of monitor pids. Each monitor prepends
# its own pid to the inherited value before launching its command.
GRP_ENVVAR = "PTIMEOUT_GRP"

class ExitCodes:
    """ Exit codes used to report the outcome to the parent process.
    These codes are aligned with the coreutils timeout implementation.
    Note that, as for coreutils, in case of termination with signal 9
    the TIMEDOUT code can't be returned, the exit code will be (128+9).
    """
    TIMEDOUT = 124  # job timed out
    CANCELED = 125  # internal error, also used for command line errors
    CANNOT_INVOKE = 126  # error executing job
    ENOENT = 127  # couldn't find job to exec

class LocalOptionParser(optparse.OptionParser):
    """
    Option parser for the ptimeout command line.

    Differences with the stock optparse.OptionParser:
    - option parsing stops at the first positional argument, such that
      the options of COMMAND are not interpreted by ptimeout,
    - any error exits with the ExitCodes.CANCELED status,
    - the version output includes the sha1 of this file,
    - parse_args() returns a single object holding both the options
      and the processed positional arguments (see process_args).
    """
    def __init__(self):
        optparse.OptionParser.__init__(
            self, prog="ptimeout",
            description="ptimeout utility, "
            "stops the given COMMAND after DURATION seconds.",
            usage="%prog [options] DURATION COMMAND..."
        )
        self.disable_interspersed_args()

    def parse_args(self, args=None, values=None):
        """
        Parse the command line and return the processed options object.
        The args and values parameters are passed unchanged to
        optparse.OptionParser.parse_args(); with the defaults the
        process command line (sys.argv[1:]) is parsed.
        """
        opts, positional = optparse.OptionParser.parse_args(
            self, args, values)
        return self.process_args(opts, positional)

    @staticmethod
    def handle_version(option, opt, value, parser):
        """
        Option callback: output the version string with the sha1 of
        this file, then exit with status 0.
        """
        with open(__file__, "rb") as f:
            sha1 = hashlib.sha1(f.read()).hexdigest()
        print("%s version %s [sha1:%s]" % (parser.prog, VERSION, sha1))
        parser.exit(0)

    def get_signum_or_exit(self, sigstr):
        """
        Return a signal number from a number or a symbolic string.
        Also handle the special case of 0 or EXIT which stand for
        a standard exit (not a signal). In this case returns 0.
        """
        if sigstr == 0 or str(sigstr).strip() == "EXIT":
            return 0
        return self.get_signum(sigstr)

    def get_signum(self, sigstr):
        """
        Return a signal number from a number or a symbolic string.
        A symbolic string is a signal name without the SIG prefix,
        for instance TERM; surrounding blanks are ignored.
        Exits with an error message on an invalid signal spec.
        """
        try:
            return int(sigstr)
        except ValueError:
            pass
        name = str(sigstr).strip()
        signum = None
        # The spec comes from the command line, thus the name is looked
        # up with getattr() and never evaluated as an expression.
        # Requiring an alphanumeric name rejects the empty string and
        # the non signal constants of the module such as SIG_IGN.
        if name.isalnum():
            signum = getattr(signal, "SIG" + name, None)
        if signum is None:
            self.exit(1, ("%s: error: invalid signal spec: %s\n"
                          % (self.prog, sigstr)))
        return signum

    def _signal_set(self, spec, convert):
        """
        Return the set of signal numbers for the comma separated list
        spec, each item being converted with the convert function.
        An empty spec gives an empty set.
        """
        if spec == "":
            return set()
        return set(convert(item) for item in spec.split(","))

    def process_args(self, opts, args):
        """
        Process parsed args into suitable form after some checks.
        Return a single namespace with all arguments: the options plus,
        unless the list option is given, the duration, command and
        arguments attributes taken from the positional arguments.
        Exits with an error message on any invalid argument.
        """
        opts.signal = self.get_signum(opts.signal)
        caught = self._signal_set(opts.catch_signals,
                                  self.get_signum_or_exit)
        exiting = self._signal_set(opts.exit_signals, self.get_signum)
        ignored = self._signal_set(opts.ignore_signals, self.get_signum)

        # Precedence when a signal is given in several lists:
        # ignore first, then exit, then catch.
        opts.ignore_signals = ignored
        opts.exit_signals = exiting.difference(ignored)
        opts.catch_signals = caught.difference(exiting).difference(ignored)

        if opts.format not in ["pid", "short", "long"]:
            self.exit(1, "%s: error: unknown list format: %s\n" %
                      (self.prog, opts.format))
        # In list mode no DURATION nor COMMAND is expected
        if opts.list is not None:
            return opts

        if len(args) < 1:
            self.exit(1, "%s: error: missing duration argument\n" % self.prog)
        # Explicit check instead of assert, which is stripped with -O
        try:
            duration = int(args[0])
        except ValueError:
            duration = -1
        if duration < 0:
            self.exit(1, "%s: error: malformed duration argument: %s\n" %
                      (self.prog, args[0]))
        opts.duration = duration
        if len(args) < 2:
            self.exit(1, "%s: error: missing command\n" % self.prog)
        opts.command = args[1]
        opts.arguments = args[2:]
        return opts

    def exit(self, status=0, message=None):  # pragma: uncovered
        """ Always exit with status CANCELED on error. """
        if status != 0:
            status = ExitCodes.CANCELED
        optparse.OptionParser.exit(self, status, message)


parser = LocalOptionParser()

# Command line options of ptimeout, as (flags, settings) pairs.
# They are registered in this order, which is also the order of the
# options in the help output.
for _flags, _settings in (
    (("-v", "--version"),
     dict(help="output version string",
          action="callback",
          callback=parser.handle_version)),
    (("-d", "--debug"),
     dict(help="debug mode",
          action="store_true")),
    (("--log-file",),
     dict(default="&stderr",
          help="log file, &stderr if not specified")),
    (("-s", "--signal"),
     dict(default="TERM",
          help="signal to be sent for terminating the COMMAND. Default: TERM")),
    (("-k", "--kill-after"),
     dict(type=int, default=10,
          help="delay for actually killing (SIGKILL) the COMMAND. Default: 10")),
    (("-c", "--catch-signals"),
     dict(default="INT,TERM,QUIT,EXIT",
          help="signals caught by timeout and propagated. Default: INT,TERM,QUIT,EXIT")),
    (("-i", "--ignore-signals"),
     dict(default="",
          help="signals that should be ignored by timeout. Default: none")),
    (("-e", "--exit-signals"),
     dict(default="",
          help="signals which force monitor to exit immediately. Default: none")),
    (("-l", "--list"),
     dict(type=int,
          help="list the processes in the given timeout session pid")),
    (("-f", "--format"),
     dict(default="short",
          help="format of the process list, one of: pid, short, long. Default: short")),
):
    parser.add_option(*_flags, **_settings)
# Do not leave the loop variables in the module namespace
del _flags, _settings

class Monitor(object):
    """
    Monitor class for interrupting a process after a given timeout.

    The monitor launches the command, arms an alarm for the timeout
    and, on timeout or on a caught signal, signals all the processes
    of its timeout group, i.e. the processes having the monitor pid in
    their GRP_ENVVAR environment variable.
    """
    def __init__(self, args):
        """ Constructor, arguments are stored into the args object. """
        self.args = args
        self.proc = None   # subprocess.Popen object of the command
        self.group = None  # timeout group: pid of the owning monitor

    class TimeoutException(Exception):
        """ Timeout exception used by the alarm to notify the monitor. """
        pass

    class KillTimeoutException(Exception):
        """ Timeout exception used by the alarm to notify kill to the monitor. """
        pass

    class TerminateException(Exception):
        """ Terminate exception raised by termination signals. """
        def __init__(self, signum):
            self.signum = signum

    @staticmethod
    def filter_is_pid(pidstr):
        """
        Filter function for valid pid identifiers.
        Returns True if pidstr converts to a non negative integer.
        """
        # Only conversion errors are caught: a bare except would also
        # swallow the exceptions raised by the signal handlers when a
        # signal is delivered while this function runs.
        try:
            return int(pidstr) >= 0
        except (ValueError, TypeError):
            return False

    def timeout_grps(self, pid):
        """
        Get the timeout groups for the given process, it is available through
        the GRP_ENVVAR (see definition above) envvar of the process.
        Returns the list of timeout groups for the process or [].
        """
        # The environment is read as bytes: it may contain sequences
        # that are not decodable in the current locale.
        try:
            with open("/proc/%d/environ" % pid, "rb") as f:
                raw_env = f.read()
        except (IOError, OSError):
            # Can't read, not accessible
            return []
        prefix = (GRP_ENVVAR + "=").encode("ascii")
        for entry in raw_env.strip().split(b'\0'):
            if entry.startswith(prefix):
                fields = entry[len(prefix):].split(b':')
                return [int(field) for field in fields
                        if self.filter_is_pid(field)]
        return []

    def all_process_list(self):
        """
        Returns the list of all visible processes pids.
        The list is empty if /proc can't be walked.
        """
        # Only the first level of /proc is needed
        for _root, dirs, _files in os.walk("/proc"):
            return [int(name) for name in dirs if self.filter_is_pid(name)]
        return []

    def sub_processes(self):
        """
        Construct the full list of process for this monitor, i.e. the
        visible processes which have self.group in their timeout groups.
        """
        return [pid for pid in self.all_process_list()
                if self.group in self.timeout_grps(pid)]

    def send_signal(self, sig):
        """
        Terminates the monitored processes with the given signal.
        All processes in the monitor hierarchy will receive the signal.
        In order to avoid race conditions when a process is created
        after the list of processes in the group is constructed, we
        iterate until no more processes in the group is found.
        Note that in the main monitor loop an alarm is setup in the
        case where the signal is not SIGKILL, thus in practice the
        loop will always be exited.
        Unless the signal is SIGKILL, each process will be signaled
        only once.
        """
        logger = logging.getLogger("ptimeout")
        signaled = set()
        while True:
            try:
                # The monitor itself is never signaled
                to_be_killed = set(pid for pid in self.sub_processes()
                                   if pid != self.group)
                if not to_be_killed:
                    break
                if logger.isEnabledFor(logging.DEBUG):
                    logger.debug("processes alive:")
                    for pid in to_be_killed:
                        logger.debug(self.format_pid(pid))
                if sig != signal.SIGKILL:
                    to_be_killed = to_be_killed.difference(signaled)
                for proc in to_be_killed:
                    logger.debug("signal pid %d with %d", proc, sig)
                    try:
                        os.kill(proc, sig)
                    except OSError:
                        # Process already killed
                        pass
                    signaled.add(proc)
                # Let the processes terminate before the next scan
                time.sleep(1)
            except self.TerminateException as e:
                # Termination is already in progress
                logger.debug("ignoring signal %d", e.signum)

    def format_pid(self, pid, fmt="short"):
        """
        Returns a formatted string for the given pid.
        The fmt is one of: "pid" for the pid only, "short" for the pid
        and the command name, "long" for the pid and the full command
        line. The command is "<died>" if it can't be read.
        """
        assert(fmt in ["pid", "short", "long"])
        if fmt == "pid":
            # No need to read the command line
            return "%d" % (pid)
        try:
            with open("/proc/%d/cmdline" % pid) as f:
                cmd_args = f.read().strip().split('\0')
        except IOError:
            cmd_args = ["<died>"]
        if fmt == "short":
            return "%d %s" % (pid, cmd_args[0])
        return "%d %s" % (pid, " ".join(cmd_args))

    def print_pids(self, pids, fmt="short", output=None):
        """
        Outputs the pids list informations, one process per line.
        The output stream defaults to the sys.stdout in effect at the
        time of the call.
        """
        assert(fmt in ["pid", "short", "long"])
        stream = sys.stdout if output is None else output
        for pid in pids:
            print(self.format_pid(pid, fmt), file=stream)

    @staticmethod
    def enum(*sequential, **named):
        """
        Generate enums: returns a class with one attribute per name,
        the sequential names being numbered from 0, and a 'names'
        attribute mapping the values back to the names.
        """
        enums = dict(zip(sequential, range(len(sequential))), **named)
        reverse = dict((value, key) for key, value in enums.items())
        enums['names'] = reverse
        return type('Enum', (), enums)

    def _open_log_stream(self):
        """
        Returns the stream for the log output as requested by the
        log_file argument, or None after an error message if the log
        file can't be opened.
        """
        log_file = self.args.log_file
        if log_file == "&stderr":
            return sys.stderr
        if log_file == "&stdout":
            return sys.stdout
        try:
            # Append mode, line buffered
            return open(log_file, "a", 1)
        except IOError as e:
            print("ptimeout: error:: can't open log file: %s" % str(e),
                  file=sys.stderr)
            return None

    def run(self):
        """
        Runs the monitor and monitored process.
        This method returns the exit code to be passed to sys.exit:
        0 in list mode, 128 + signal number after a caught signal,
        ExitCodes.TIMEDOUT after a timeout, an ExitCodes error code if
        the command or the log file can't be set up, otherwise the
        exit code of the command.
        """

        # Setup ptimeout logger
        log_fmt = "%(levelname)s: %(name)s: %(process)d: %(message)s"
        log_lvl = logging.DEBUG if self.args.debug else logging.INFO
        log_stream = self._open_log_stream()
        if log_stream is None:
            return ExitCodes.CANNOT_INVOKE
        logging.basicConfig(stream=log_stream, level=log_lvl,
                            format=log_fmt)
        logger = logging.getLogger("ptimeout")

        def inhibit_signals():
            """ Inhibit all signals from the catch_signals set. """
            # NOTE(review): the reason for this delay is not visible
            # in this file, confirm before changing it.
            time.sleep(1)
            for signum in self.args.catch_signals:
                # 0 stands for EXIT, it is not an actual signal
                if signum != 0:
                    signal.signal(signum, signal.SIG_IGN)

        def timeout_handler(signum, frame):
            """ Handler for the alarm. """
            inhibit_signals()
            logger.debug("timeout_handler: got signal %d", signum)
            raise Monitor.TimeoutException()

        def timeout_kill_handler(signum, frame):
            """ Handler for the kill signal alarm. """
            inhibit_signals()
            logger.debug("timeout_kill_handler: got signal %d", signum)
            raise Monitor.KillTimeoutException()

        def terminate_handler(signum, frame):
            """ Handler for termination through signals. """
            inhibit_signals()
            logger.debug("terminate_handler: got signal %d", signum)
            raise Monitor.TerminateException(signum)

        def interrupt_handler(signum, frame):  # pragma: uncovered
            """ Handler for signals that require immediate exit with message. """
            logger.debug("interrupt_handler: got signal %d", signum)
            print("ptimeout: interrupted by signal %d" % signum,
                  file=sys.stderr)
            sys.exit(128 + signum)

        # Show processes for the given group
        if self.args.list is not None:
            self.group = self.args.list
            self.print_pids(self.sub_processes(), self.args.format)
            return 0

        # Set the timeout group to the monitor pid, it is prepended to
        # the groups inherited from enclosing monitors if any
        self.group = os.getpid()
        grps = os.environ.get(GRP_ENVVAR, None)
        os.environ[GRP_ENVVAR] = (str(self.group) if not grps else
                                  ':'.join((str(self.group), grps)))

        # Install default handler for ^C
        if signal.SIGINT not in self.args.catch_signals:
            signal.signal(signal.SIGINT, interrupt_handler)

        # Install termination handler for all requested signals
        for signum in self.args.catch_signals:
            if signum != 0:
                signal.signal(signum, terminate_handler)

        # Install interrupt handler for all immediate exit signals
        for signum in self.args.exit_signals:
            signal.signal(signum, interrupt_handler)

        # Install null handler for all ignored signals
        for signum in self.args.ignore_signals:
            signal.signal(signum, signal.SIG_IGN)

        # Launch monitored process
        command_args = [self.args.command] + self.args.arguments
        command_line = " ".join(command_args)
        try:
            self.proc = subprocess.Popen(command_args)
        except OSError as e:
            print("ptimeout: error: failed to run " \
                  "command: %s : %s" % (e.strerror, command_line),
                  file=sys.stderr)
            if e.errno == errno.ENOENT:
                return ExitCodes.ENOENT
            return ExitCodes.CANNOT_INVOKE

        # Install timer through alarm, a duration of 0 means no timeout
        signal.signal(signal.SIGALRM, timeout_handler)
        if self.args.duration > 0:
            logger.debug("setting timeout alarm for %d secs",
                         self.args.duration)
            signal.alarm(self.args.duration)

        # Monitoring loop: a state machine where the events are either
        # the exit of the command or the exceptions raised by the
        # signal handlers installed above
        S = self.enum('RUNNING', 'SIGNALING', 'KILLING', 'COMPLETING')
        E = self.enum('NONE', 'TIMEOUT', 'KTIMEOUT', 'SIGNAL', 'EXIT')
        state = S.RUNNING
        event = E.NONE
        # With EXIT (0) in the caught signals, the remaining processes
        # of the group are also terminated when the command exits
        collect_on_exit = 0 in self.args.catch_signals
        first_signum = None
        exit_code = None
        states = set()
        events = set()
        while True:
            try:
                # Transitions
                new_state = state
                if state == S.RUNNING:
                    if event == E.TIMEOUT or event == E.SIGNAL:
                        if self.args.signal == signal.SIGKILL:
                            new_state = S.KILLING
                        else:
                            new_state = S.SIGNALING
                    elif event == E.EXIT:
                        if collect_on_exit:
                            new_state = S.SIGNALING
                        else:
                            new_state = S.COMPLETING
                elif state == S.SIGNALING:
                    if event == E.KTIMEOUT:
                        new_state = S.KILLING
                    elif event == E.EXIT:
                        new_state = S.COMPLETING
                elif state == S.KILLING:
                    if event == E.EXIT:
                        new_state = S.COMPLETING
                elif state == S.COMPLETING:
                    pass
                logger.debug("transition: %s -- %s --> %s",
                             S.names[state], E.names[event],
                             S.names[new_state])
                # Actions
                state = new_state
                states.add(state)
                events.add(event)
                event = E.NONE
                if state == S.KILLING:
                    logger.debug("killing all processes with sig %d",
                                 signal.SIGKILL)
                    self.send_signal(signal.SIGKILL)
                elif state == S.SIGNALING:
                    logger.debug("setting kill alarm for %d secs",
                                 self.args.kill_after)
                    signal.signal(signal.SIGALRM, timeout_kill_handler)
                    signal.alarm(self.args.kill_after)
                    logger.debug("terminating all processes with sig %d",
                                 self.args.signal)
                    self.send_signal(self.args.signal)
                elif state == S.COMPLETING:
                    break
                # Wait for child unless already exited
                if exit_code is None:
                    exit_code = self.proc.wait()
                    logger.debug("child exit with code: %d", exit_code)
                event = E.EXIT
            except self.TimeoutException:
                event = E.TIMEOUT
            except self.KillTimeoutException:
                event = E.KTIMEOUT
            except self.TerminateException as e:
                event = E.SIGNAL
                # The exit code reports the first signal received
                if first_signum is None:
                    first_signum = e.signum

        # Report the reason of the termination if not a normal exit
        outcome = "Killed" if S.KILLING in states else "Terminated"
        if E.SIGNAL in events:
            print("%s after signal %d: %d: %s" %
                  (outcome, first_signum, os.getpid(), command_line),
                  file=sys.stderr)
            return 128 + first_signum
        if E.TIMEOUT in events:
            print("%s after timeout: %d: %s" %
                  (outcome, os.getpid(), command_line),
                  file=sys.stderr)
            return ExitCodes.TIMEDOUT
        return exit_code

# Script entry: parse the command line into the module level args object,
# then run the monitor; the value returned by run() is the exit status.
args = parser.parse_args()
sys.exit(Monitor(args).run())
