#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright (C) 2020 KuraLabs S.R.L
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

"""
Cleanly kills a process.

Sends a SIGTERM signal and waits for the process to die. If it doesn't, sends
SIGKILL to all its subprocesses in the process tree, from higher (recent) to
lower (former) PID value, and finally to the process itself.

Mandatory dependencies:

- psutil: Process status and subprocesses.

Optional dependencies:

- colorlog: for colored logging.
"""

import logging
from time import time
from argparse import ArgumentParser

from psutil import Process, NoSuchProcess, TimeoutExpired, wait_procs


# Script metadata.
__author__ = 'KuraLabs S.R.L'
__email__ = 'info@kuralabs.io'
# Printed by the ``--version`` command line option (see setup()).
__version__ = '1.0.0'


# Module-level logger. Its output format and level come from the root logger
# configuration performed in setup().
log = logging.getLogger(__name__)


def main(pid, timeout_s, **kwargs):
    """
    Cleanly kill the process identified by ``pid``.

    A SIGTERM is sent first and the process is given up to ``timeout_s``
    seconds to exit. If it is still alive after that, every descendant in its
    process tree is sent a SIGKILL, from the highest PID to the lowest,
    followed by the process itself. All of them are then awaited for up to
    ``timeout_s`` seconds more.

    :param int pid: PID of the process to kill.
    :param timeout_s: Maximum time, in seconds, to wait in each of the two
     waiting phases (after SIGTERM and after SIGKILL).
    :param kwargs: Extra parsed command line options (e.g. ``verbosity``,
     ``colorize``). Accepted so that the whole argparse namespace can be
     passed in; ignored here.

    :return: Process exit code: 0 if the process (and, when it came to that,
     its process tree) is gone, 1 if the PID didn't exist to begin with or if
     some process survived the SIGKILL.
    :rtype: int
    """

    # Get process
    try:
        process = Process(pid=pid)
        name = process.name()
    except NoSuchProcess:
        log.critical('No such process PID {}'.format(pid))
        return 1

    log.info('Process PID {} identified as "{}"'.format(pid, name))

    # Terminate process
    log.info('Sending a SIGTERM to PID {} ...'.format(pid))
    start = time()
    try:
        process.terminate()
    except NoSuchProcess:
        # The process exited on its own between the lookup and the signal.
        log.info(
            'Process PID {} exited before the SIGTERM was sent'.format(pid)
        )
        return 0

    log.info(
        'Waiting up to {}s for PID {} to terminate ...'.format(timeout_s, pid)
    )
    try:
        process.wait(timeout=timeout_s)
        # Fixed-point format: a bare '.2' precision means two significant
        # digits and renders e.g. 12.3 as '1.2e+01'.
        log.info(
            'Process PID {} terminated succesfully after {:.2f}s'.format(
                pid, time() - start,
            )
        )
        return 0

    except TimeoutExpired:
        pass

    # Kill process tree
    log.warning(
        'The process PID {} didn\'t terminate. '
        'Fetching process tree ...'.format(
            process.pid,
        )
    )
    try:
        # psutil exposes the descendants of a process through children().
        subprocesses = process.children(recursive=True)
    except NoSuchProcess:
        # The process finally died right after the timeout expired.
        log.info(
            'Process PID {} exited while fetching its process tree'.format(
                process.pid,
            )
        )
        return 0

    log.info('Process PID {} has {} subprocesses.'.format(
        process.pid,
        len(subprocesses),
    ))

    # Highest PID first, so that the most recently spawned descendants are
    # killed before their ancestors.
    for child in sorted(subprocesses, key=lambda p: p.pid, reverse=True):
        try:
            log.info(
                'Sending SIGKILL to subprocess PID {} '
                'identified by {} ...'.format(
                    child.pid,
                    child.name(),
                )
            )
            child.kill()
        except NoSuchProcess:
            # Already gone: nothing left to kill. wait_procs() below will
            # report it as gone.
            log.debug(
                'Subprocess PID {} exited before the SIGKILL was sent'.format(
                    child.pid,
                )
            )

    try:
        process.kill()
    except NoSuchProcess:
        log.debug(
            'Process PID {} exited before the SIGKILL was sent'.format(
                process.pid,
            )
        )

    # Wait for all process to be killed
    subprocesses.append(process)

    log.info(
        'Waiting up to {}s for PIDs {} to be killed ...'.format(
            timeout_s,
            ', '.join(map(str, sorted(
                p.pid for p in subprocesses
            ))),
        )
    )

    def on_kill(process):
        log.info('Process PID {} was killed.'.format(process.pid))

    gone, alive = wait_procs(
        subprocesses,
        timeout=timeout_s,
        callback=on_kill,
    )

    if alive:
        # PIDs are integers and must be converted before being joined.
        log.critical('Processes with PIDs {} where not killed'.format(
            ', '.join(map(str, sorted(p.pid for p in alive))),
        ))
        return 1

    log.info('All processes were killed.')
    return 0


def setup():
    """
    Parse the command line arguments and configure logging.

    The log level is derived from the number of ``-v`` flags given: none
    selects ERROR, one WARNING, two INFO and three or more DEBUG. Unless
    ``--no-color`` is passed, the colored formatter from ``colorlog`` is used
    when that package can be imported; otherwise the standard formatter is
    used.

    :return: The parsed arguments, with attributes ``verbosity``,
     ``colorize``, ``timeout_s`` and ``pid``.
    :rtype: argparse.Namespace
    """

    cli = ArgumentParser(description='Cleanly kills a process')

    # Standard options
    cli.add_argument(
        '-v', '--verbose',
        dest='verbosity',
        action='count',
        default=0,
        help='Increase verbosity level',
    )
    cli.add_argument(
        '--version',
        action='version',
        version='{} {}'.format(cli.description, __version__),
    )
    cli.add_argument(
        '--no-color',
        dest='colorize',
        action='store_false',
        help='Do not colorize the log output'
    )

    # Clean kill options
    cli.add_argument(
        '--timeout-s',
        type=int,
        default=30,
        help='Maximum time to wait for the process to die, in seconds.',
    )
    cli.add_argument(
        dest='pid',
        metavar='PID',
        type=int,
        help='PID of process to kill',
    )

    options = cli.parse_args()

    # Plain formatter by default; upgraded below when colors are wanted and
    # colorlog is installed.
    formatter_class = logging.Formatter
    template = '  {asctime} | {levelname:8} | {message}'

    if options.colorize:
        try:
            from colorlog import ColoredFormatter
        except ImportError:
            # colorlog is optional: keep the plain formatter.
            pass
        else:
            formatter_class = ColoredFormatter
            template = (
                '  {thin_white}{asctime}{reset} | '
                '{log_color}{levelname:8}{reset} | '
                '{log_color}{message}{reset}'
            )

    # Indexed by the number of -v flags; anything past the end means DEBUG.
    levels = (logging.ERROR, logging.WARNING, logging.INFO, logging.DEBUG)
    if options.verbosity < len(levels):
        level = levels[options.verbosity]
    else:
        level = logging.DEBUG

    handler = logging.StreamHandler()
    handler.setFormatter(formatter_class(fmt=template, style='{'))
    logging.basicConfig(handlers=[handler], level=level)

    log.debug('Verbosity at level {}'.format(options.verbosity))

    return options


if __name__ == '__main__':
    # Script entry point: parse the command line, run the kill sequence and
    # use its return value as the process exit status.
    args = setup()
    # Raise SystemExit directly instead of calling exit(): that builtin is
    # injected by the ``site`` module for interactive use and is missing when
    # the interpreter runs with -S.
    raise SystemExit(main(**vars(args)))
