#!/usr/bin/env python3

"""Find and kill lava workers."""

from __future__ import annotations

import argparse
import logging
import os
import signal
import sys
from contextlib import suppress
from pathlib import Path
from threading import Event, Thread
from time import monotonic, sleep, time
from typing import Any

import boto3
import psutil

from lava.config import config
from lava.lib.argparse import ArgparserExitError, ArgparserNoExit
from lava.lib.datetime import duration_to_seconds
from lava.lib.logging import setup_logging
from lava.lib.misc import dict_strip
from lava.lib.os import signum
from lava.version import __version__

__author__ = 'Murray Andrews'

# Program name: base name of the invoked script with any extension stripped.
# It is also used as the logger name and as the default log tag.
PROG = os.path.splitext(os.path.basename(sys.argv[0]))[0]
LOGNAME = PROG
LOGLEVEL = 'info'  # Default level, also used for logging before CLI args are parsed
LOG = logging.getLogger(LOGNAME)

# Timing parameters -- all values are in seconds.
SIGNS_OF_LIFE_TIME = 3  # Pause after the termination signal before looking for survivors
COUP_DE_GRACE_TIME = 10  # Further pause before survivors are sent SIGKILL
SLEEP_TIME = 10  # Polling interval while waiting for worker daemons to finish
MIN_HEARTBEAT = 60  # Smallest permitted interval between lifecycle heartbeats


# ------------------------------------------------------------------------------
class AutoscalingLifecycleHeartbeart(Thread):
    """
    Send AWS auto scaling lifecycle heartbeats at regular intervals.

    The thread is a daemon, so it does not keep the process alive once the
    main thread exits. Errors from the heartbeat call are logged and the loop
    carries on, so one failed heartbeat does not stop later ones. Call stop()
    to end the loop early.

    :param profile:             AWS credentials profile.
    :param as_lifecycle_args:   Lifecycle event args for the boto calls.
    :param interval:            Send heartbeats at this periodicity (seconds).
                                Must be at least MIN_HEARTBEAT.
    :raise ValueError:          If interval is less than MIN_HEARTBEAT.

    """

    # --------------------------------------------------------------------------
    def __init__(
        self,
        profile: str,
        as_lifecycle_args: dict[str, str],
        interval: int | float,
        *args,
        **kwargs,
    ):
        """Set up the timer thread."""

        if interval < MIN_HEARTBEAT:
            raise ValueError(f'interval < {MIN_HEARTBEAT}')

        super().__init__(*args, **kwargs)
        self.daemon = True

        # Deadlines are tracked on the monotonic clock so that wall clock
        # adjustments (e.g. an NTP step) cannot cause a burst of heartbeats or
        # a long gap between them.
        self._next_ts = monotonic() + interval
        # Set by stop(); also used as an interruptible sleep in run().
        self._stop_requested = Event()
        self.as_lifecycle_args = as_lifecycle_args
        self.interval = interval
        self._as_client = boto3.Session(profile_name=profile).client('autoscaling')

    # --------------------------------------------------------------------------
    def stop(self) -> None:
        """Ask the heartbeat loop to exit. A sleeping loop wakes immediately."""

        self._stop_requested.set()

    # --------------------------------------------------------------------------
    def run(self):
        """Run the heartbeat loop until stop() is called."""

        LOG.debug('Starting')

        while not self._stop_requested.is_set():
            delay = self._next_ts - monotonic()
            if delay > 0:
                # Event.wait() returns early if stop() is called meanwhile.
                self._stop_requested.wait(delay)
                continue

            self._next_ts += self.interval

            try:
                self._as_client.record_lifecycle_action_heartbeat(**self.as_lifecycle_args)
                LOG.info('Lifecycle hook heartbeat sent')
            except Exception as e:
                # Best effort: log and keep going; the next heartbeat may succeed.
                LOG.error('Lifecycle heartbeat error: %s', e)


# ------------------------------------------------------------------------------
def process_cli_args() -> argparse.Namespace:
    """
    Process and validate the command line arguments.

    Duration arguments are validated here so that a malformed value is
    reported before any lava worker is signalled. --lifecycle-heartbeat is
    converted to seconds in place; --wait is validated but left as given.

    :return:    The args namespace.
    """

    argp = argparse.ArgumentParser(prog=PROG, description='Stop lava worker processes.')

    # ------------------------------
    # General options

    argp.add_argument(
        '-D',
        '--no-dispatch',
        action='store_true',
        help=(
            f'Inhibit further scheduled dispatches by creating {config("NO_DISPATCH")}.'
            ' This requires the lava-dispatcher utility to check for this file by'
            ' specifying the --check-dispatch argument.'
        ),
    )

    argp.add_argument('--profile', action='store', help='As for AWS CLI.')

    argp.add_argument(
        '--signal',
        '--sig',
        action='store',
        default=0,
        help=(
            'Send the specified signal to the lava worker processes. Can be specified'
            ' as a signal name (e.g. SIGHUP or HUP) or a signal number. The default'
            ' is 0 which only tests if the process exists. SIGHUP is interpreted as a'
            ' controlled shutdown instruction allowing running jobs to complete.'
            ' SIGTERM is interpreted as a controlled, but immediate, termination that'
            ' allows final cleanup tasks but takes no account of running jobs.'
            ' See -w, --wait.'
        ),
    )

    argp.add_argument('-v', '--version', action='version', version=__version__)

    argp.add_argument(
        '-w',
        '--wait',
        metavar='DURATION',
        action='store',
        help=(
            'Wait for up to the specified duration for the lava workers to finish'
            ' voluntarily before killing them. This requires the signal to be set'
            ' to SIGHUP / HUP as this is interpreted by the lava worker daemons'
            ' as a controlled shutdown request.'
            ' The duration must be in the form nn[X] where nn is a number and X is'
            ' one of s (seconds), m (minutes) or h (hours). If X is not specified,'
            ' seconds are assumed.'
        ),
    )

    # ------------------------------
    # AWS auto scaling actions

    aws_argp = argp.add_argument_group(
        'AWS auto scaling lifecycle options',
        description=(
            'These arguments are designed to complete an AWS auto scaling'
            ' "EC2 Instance-terminate Lifecycle Action". See the AWS CLI or'
            ' AWS auto scaling documentation for meaning and usage.'
            ' Note that the lifecycle action result is always set to CONTINUE'
            ' which means the auto scaling group _will_ terminate the instance.'
        ),
    )

    aws_argp.add_argument(
        '--auto-scaling-group-name',
        metavar='NAME',
        action='store',
        help=(
            'Send a complete-lifecycle-action signal for the specified AWS'
            ' auto scaling group. If specified, --lifecycle-hook-name is also'
            ' required, as is at least one of --instance-id or'
            ' --lifecycle-action-token.'
        ),
    )

    aws_argp.add_argument(
        '--instance-id',
        metavar='ID',
        action='store',
        help=(
            'The ID of the EC2 instance (optional).'
            ' If specified, --auto-scaling-group-name / --lifecycle-hook-name are required.'
        ),
    )

    aws_argp.add_argument(
        '--lifecycle-action-token',
        metavar='UUID',
        action='store',
        help=(
            'The lifecycle action identifier (optional).'
            ' If specified, --auto-scaling-group-name / --lifecycle-hook-name are required.'
        ),
    )

    aws_argp.add_argument(
        '--lifecycle-hook-name',
        metavar='NAME',
        action='store',
        help=(
            'The name of the AWS auto scaling lifecycle hook. If specified,'
            ' --auto-scaling-group-name is also required.'
        ),
    )

    aws_argp.add_argument(
        '--lifecycle-heartbeat',
        metavar='DURATION',
        action='store',
        help=(
            'Record a heartbeat for the lifecycle action at specified intervals'
            ' (optional).'
            ' If specified, --auto-scaling-group-name / --lifecycle-hook-name are required.'
            ' The duration must be in the form nn[X] where nn is a number and X is'
            ' one of s (seconds), m (minutes) or h (hours). If X is not specified,'
            ' seconds are assumed.'
            f' The minimum permitted value is {MIN_HEARTBEAT} seconds.'
        ),
    )

    # ------------------------------
    # Logging options

    logp = argp.add_argument_group('logging arguments')
    logp.add_argument(
        '-c',
        '--no-colour',
        '--no-color',
        dest='colour',
        action='store_false',
        default=True,
        help='Don\'t use colour in information messages.',
    )

    logp.add_argument(
        '-l',
        '--level',
        metavar='LEVEL',
        default=LOGLEVEL,
        help=(
            'Print messages of a given severity level or above. The standard'
            ' logging level names are available but debug, info, warning and'
            f' error are most useful. The Default is {LOGLEVEL}.'
        ),
    )

    logp.add_argument(
        '--log',
        action='store',
        help='Log to the specified target. This can be either a file'
        ' name or a syslog facility with an @ prefix (e.g. @local0).',
    )

    logp.add_argument(
        '--tag',
        action='store',
        default=PROG,
        help=f'Tag log entries with the specified value. The default is {PROG}.',
    )

    args = argp.parse_args()

    try:
        if args.signal != 0:
            args.signal = signum(args.signal)
    except ValueError as e:
        argp.error(str(e))

    # Validate --wait now so a bad value fails before any worker is signalled.
    # args.wait is deliberately left unconverted; main() does the conversion.
    # NOTE(review): assumes duration_to_seconds() raises ValueError on bad input.
    if args.wait:
        try:
            duration_to_seconds(args.wait)
        except ValueError as e:
            argp.error(f'--wait: {e}')

    if args.lifecycle_heartbeat:
        try:
            args.lifecycle_heartbeat = duration_to_seconds(args.lifecycle_heartbeat)
        except ValueError as e:
            argp.error(f'--lifecycle-heartbeat: {e}')
        if args.lifecycle_heartbeat < MIN_HEARTBEAT:
            argp.error(f'lifecycle-heartbeat must be >= {MIN_HEARTBEAT} seconds')

    # Verify the auto scaling options
    if any((args.auto_scaling_group_name, args.lifecycle_hook_name)):
        if not all((args.auto_scaling_group_name, args.lifecycle_hook_name)):
            argp.error('Inconsistent auto scaling arguments')
        if not any((args.instance_id, args.lifecycle_action_token)):
            argp.error('At least one of --instance-id or --lifecycle-action-token are required')
    elif any((args.instance_id, args.lifecycle_action_token, args.lifecycle_heartbeat)):
        argp.error('Inconsistent auto scaling arguments')

    return args


# ------------------------------------------------------------------------------
def process_worker_args(argv: list[str]) -> argparse.Namespace:
    """
    Extract the realm and worker name from a lava worker's arguments.

    Only --realm / -r and --worker / -w are recognised. Any other arguments
    are collected by parse_known_args() and discarded.

    :param argv:    The worker's argument list, as it would appear in argv
                    for the worker.
    :return:        A namespace with realm and worker attributes (None when
                    the option is absent).
    """

    # NOTE(review): ArgparserNoExit presumably raises ArgparserExitError
    # instead of exiting on a parse error -- the caller catches that type.
    parser = ArgparserNoExit()

    for short_opt, long_opt in (('-r', '--realm'), ('-w', '--worker')):
        parser.add_argument(short_opt, long_opt, action='store')

    known, _unrecognised = parser.parse_known_args(argv)
    return known


# ------------------------------------------------------------------------------
# noinspection PyUnresolvedReferences
def get_lava_workers() -> list[dict[str, Any]]:
    """
    Get process information for lava worker processes.

    A process is treated as a lava worker daemon when its interpreter looks
    like python, its script's base name is lava-worker or lava-worker.py, and
    its arguments include both --realm and --worker.

    :return:    A list of psutil proc info dicts containing 'cmdline' and
                'pid', plus an added 'name' key of the form REALM-WORKER.
    """

    worker_scripts = ('lava-worker', 'lava-worker.py')
    worker_info = []

    for proc in psutil.process_iter(attrs=['cmdline', 'pid']):
        cmd = proc.info['cmdline']
        # cmdline may be empty, or None when psutil cannot read it.
        if not cmd or len(cmd) < 2:
            continue

        # Compare the script's base name rather than using a '/...' suffix
        # match, so that a worker started with a bare relative script name
        # (e.g. "python lava-worker.py ...") is also found.
        if 'python' not in cmd[0].lower() or os.path.basename(cmd[1]) not in worker_scripts:
            continue

        try:
            worker_args = process_worker_args(cmd[2:])
        except ArgparserExitError:
            continue
        if not worker_args.realm or not worker_args.worker:
            continue

        # OK looks like a real worker daemon.
        proc.info['name'] = f'{worker_args.realm}-{worker_args.worker}'
        worker_info.append(proc.info)

    return worker_info


# ------------------------------------------------------------------------------
def proc_exists(pid: int) -> bool:
    """
    Check if a process exists.

    Sends signal 0, which performs the existence and permission checks
    without delivering any signal.

    :param pid:     Process ID. Values <= 0 return False, because os.kill()
                    would treat them as process-group targets rather than a
                    single process.
    :return:        True if it exists, False otherwise.
    """

    if pid <= 0:
        return False

    try:
        os.kill(pid, 0)
    except ProcessLookupError:
        return False
    except PermissionError:
        # The process exists but we are not allowed to signal it
        # (e.g. it belongs to another user).
        return True

    return True


# ------------------------------------------------------------------------------
def wait_and_terminate_procs(
    procs: list[dict[str, Any]], wait: int | float, sig: int = signal.SIGTERM
) -> None:
    """
    Wait for processes to terminate on their own and then kill them.

    Liveness is polled every SLEEP_TIME seconds (less near the deadline) for
    up to ``wait`` seconds. Any process still running after that is sent
    ``sig``. If it is still alive SIGNS_OF_LIFE_TIME seconds later, it gets
    another COUP_DE_GRACE_TIME seconds and is then sent SIGKILL.

    :param procs:   A list of proc_info dicts from psutil process iterator.
                    Each must have 'pid' and 'name' keys.
    :param wait:    Number of seconds to wait for the process to finish on its
                    own before terminating it.
    :param sig:     The termination signal to send (SIGTERM or SIGKILL usually).
                    Default is SIGTERM. This will be followed up shortly after
                    by a SIGKILL if the process is stubborn.
    """

    deadline = monotonic() + wait
    running = {p['pid'] for p in procs}

    # Poll before sleeping, so processes that have already gone are noticed
    # straight away, and never sleep past the deadline.
    while True:
        for proc_info in procs:
            pid = proc_info['pid']
            if pid not in running:
                continue
            if proc_exists(pid):
                LOG.debug('%s (%s): running', proc_info['name'], pid)
            else:
                LOG.info('%s (%s): finished', proc_info['name'], pid)
                running.discard(pid)

        remaining = deadline - monotonic()
        if not running or remaining <= 0:
            break
        sleep(min(SLEEP_TIME, remaining))

    if not running:
        return

    # Terminate any processes still running
    survivors = [p for p in procs if p['pid'] in running]
    for proc_info in survivors:
        LOG.warning('%s (%s): terminating', proc_info['name'], proc_info['pid'])
        with suppress(ProcessLookupError):
            os.kill(proc_info['pid'], sig)

    # Time to get serious -- kill anything remaining. Only processes that were
    # actually sent the termination signal are considered. A process that
    # finished earlier may have had its PID reused by an unrelated process.
    sleep(SIGNS_OF_LIFE_TIME)
    LOG.debug('Checking for signs of life')
    survivors = [p for p in survivors if proc_exists(p['pid'])]
    if not survivors:
        return

    LOG.debug('Waiting to deliver coup de grâce')
    sleep(COUP_DE_GRACE_TIME)
    for proc_info in survivors:
        # It may have exited during the grace period.
        if not proc_exists(proc_info['pid']):
            LOG.info('%s (%s): finished', proc_info['name'], proc_info['pid'])
            continue
        LOG.warning('%s (%s): killing', proc_info['name'], proc_info['pid'])
        with suppress(ProcessLookupError):
            os.kill(proc_info['pid'], signal.SIGKILL)


# ------------------------------------------------------------------------------
def main() -> int:
    """
    Do the business.

    Steps, in order:
      1. Find the lava worker daemons.
      2. Optionally create the no-dispatch file.
      3. Optionally start lifecycle heartbeats.
      4. Signal each worker.
      5. Optionally wait for the workers to exit, terminating stragglers.
      6. Complete the auto scaling lifecycle action, if one was requested.

    The no-dispatch file and the lifecycle completion are handled even when
    no workers are found. This means an instance with nothing to drain is
    not left waiting for the lifecycle hook to time out.

    :return:    status
    """

    setup_logging(LOGLEVEL, name=LOGNAME, prefix=PROG)
    args = process_cli_args()
    setup_logging(args.level, name=LOGNAME, target=args.log, colour=args.colour, prefix=args.tag)

    # Convert the wait period before any worker is signalled. A malformed
    # duration then aborts the run up front, not after the signals have gone out.
    wait_seconds = duration_to_seconds(args.wait) if args.wait else None

    lava_workers = get_lava_workers()
    if not lava_workers:
        LOG.info('No lava workers found')

    # ----------------------------------------
    # Inhibit schedule dispatch events
    if args.no_dispatch:
        no_dispatch_file = config('NO_DISPATCH')
        LOG.info('Creating %s', no_dispatch_file)
        Path(no_dispatch_file).touch()

    # ----------------------------------------
    # Prep for auto scaling lifecycle activity
    as_client = None
    as_lifecycle_args = None
    if args.auto_scaling_group_name:
        as_client = boto3.Session(profile_name=args.profile).client('autoscaling')
        as_lifecycle_args = dict_strip(
            {
                'LifecycleHookName': args.lifecycle_hook_name,
                'AutoScalingGroupName': args.auto_scaling_group_name,
                'LifecycleActionToken': args.lifecycle_action_token,
                'InstanceId': args.instance_id,
            }
        )
        if args.lifecycle_heartbeat:
            # process_cli_args() has already converted this to seconds.
            AutoscalingLifecycleHeartbeart(
                args.profile,
                as_lifecycle_args,
                args.lifecycle_heartbeat,
                name='heartbeat',  # noqa
            ).start()

    # ----------------------------------------
    # Two phase shutdown of lava workers -- warning then kill (after wait period)
    for proc_info in lava_workers:
        pid = proc_info['pid']
        LOG.debug('%s (%d): %s', proc_info['name'], pid, ' '.join(proc_info['cmdline'][1:]))
        try:
            os.kill(pid, args.signal)
        except ProcessLookupError:
            LOG.info('PID %s: No such process', pid)
        except Exception as e:
            LOG.error('PID %s: %s', pid, e)
        else:
            LOG.info('{name} [{pid}] sent signal {sig}'.format(sig=args.signal, **proc_info))

    if wait_seconds is not None and lava_workers:
        LOG.info('Workers have %s seconds to terminate', wait_seconds)
        wait_and_terminate_procs(lava_workers, wait_seconds)
        LOG.info('Workers shutdown complete')

    # ----------------------------------------
    # Complete the auto scaling lifecycle action
    if as_client:
        LOG.info('Completing auto scaling lifecycle action')
        as_client.complete_lifecycle_action(LifecycleActionResult='CONTINUE', **as_lifecycle_args)

    return 0


# ------------------------------------------------------------------------------
if __name__ == '__main__':
    # For debugging, replace the try block below with this bare call so that
    # exceptions propagate with a full traceback:
    # sys.exit(main())  # noqa: ERA001
    try:
        # sys.exit() rather than the site-provided exit() builtin, which is
        # not available when Python runs with -S.
        sys.exit(main())
    except Exception as ex:
        LOG.error('%s', ex)
        sys.exit(1)
    except KeyboardInterrupt:
        LOG.warning('Interrupt')
        sys.exit(2)
