#!/usr/bin/env python3

"""
Lava worker.

Read dispatch messages from a worker queue and process them.

"""

# ..............................................................................
# region imports
# ..............................................................................

from __future__ import annotations

import argparse
import atexit
import json
import logging
import os
import platform
import resource
import signal
import socket
import sys
import time
from pathlib import Path
from queue import Queue
from shutil import rmtree
from threading import Thread
from typing import Any, Callable, NoReturn

import boto3
import dateutil.parser
import docker
import psutil

from lava.config import LOGLEVEL, LOGNAME, __config__, config, config_load
from lava.event import EventRecorder
from lava.lava import run_job
from lava.lavacore import DEFER_ON_EXIT, LavaError, ThreadMonitor, dispatch, get_job_spec
from lava.lib.aws import ec2_instance_id
from lava.lib.daemon import daemonify
from lava.lib.datetime import duration_to_seconds, now_tz
from lava.lib.logging import JsonFormatter, setup_logging
from lava.lib.misc import Defer, Task, dict_check, str2bool
from lava.version import __version__

# ..............................................................................
# endregion imports
# ..............................................................................


# ..............................................................................
# region constants and globals
# ..............................................................................

__author__ = 'Murray Andrews'

PROG = os.path.splitext(os.path.basename(sys.argv[0]))[0]
LOG = logging.getLogger(name=LOGNAME)

LONG_POLL_DEFAULT = 10  # Wait 10 seconds for SQS messages to arrive
LONG_POLL_MAX = 20  # Max wait time
THREADS = 5
RETRIES = 2  # No. retries to process an event before giving up on it.
SQS_POLL_SLEEP = 10  # Sleep between SQS polls when no messages available.
BUSY_SLEEP = 5  # Sleep for this many seconds if all threads busy.
HEARTBEAT_MIN = 30  # Minimum time between heartbeat messages
WORKER_QUEUE_JOIN_WAIT = 3  # Seconds to wait on each iteration while joining on job queue
ZOMBIE_STRIKES = 3  # Number of checks for job tmp dirs before exiting.
WORKER_THREAD_NAME = 'worker-'  # Worker thread names have this prefix

AWS_API_BACKOFF = 1.0
AWS_API_RETRIES = 4

JOB_DISPATCH_REQUIRED_FIELDS = {'realm', 'worker', 'job_id', 'run_id', 'ts_dispatch'}
JOB_DISPATCH_OPTIONAL_FIELDS = {'parameters', 'globals'}

SOFT_EXIT_SIGNALS = (signal.SIGHUP, signal.SIGINT)
HARD_EXIT_SIGNALS = (signal.SIGTERM,)
interrupted = 0

# Fields emitted when JSON logging is enabled
JSON_LOG_FIELDS = {
    'localtime': 'asctime',
    'timestamp': 'isotime',  # This is a custom extra in the record
    'level': 'levelname',
    'message': 'message',
    'thread': 'threadName',
    'pid': 'process',
}
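
# With these fields, a JSON log line looks roughly like this (illustrative
# values; the exact output depends on JsonFormatter):
#   {"localtime": "...", "timestamp": "2024-01-01T00:00:00+00:00", "level": "INFO",
#    "message": "Worker ready ...", "thread": "MainThread", "pid": 1234}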

pid = None

# ..............................................................................
# endregion constants and globals
# ..............................................................................


# ------------------------------------------------------------------------------
class LavaQueueTimeoutError(Exception):
    """For join timeouts."""

    pass


class LavaQueue(Queue):
    """Same as normal queue but with a timeout option on join()."""

    def join(self, timeout: float | None = None) -> None:
        """Join on a queue with an optional timeout."""

        if not timeout:
            return super().join()

        with self.all_tasks_done:
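            # Mirrors Queue.join(): wait on the all_tasks_done condition until
            # every put() has been matched by a task_done(), or time runs out.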
            end_time = time.time() + timeout
            while self.unfinished_tasks:
                remaining_time = end_time - time.time()
                if remaining_time <= 0:
                    raise LavaQueueTimeoutError()

                self.all_tasks_done.wait(remaining_time)

        return None
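
# Illustrative usage of the bounded join (this mirrors the shutdown loop in
# main()):
#
#     try:
#         job_queue.join(timeout=WORKER_QUEUE_JOIN_WAIT)
#     except LavaQueueTimeoutError:
#         ...  # worker threads still busy; decide whether to keep waiting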


# ------------------------------------------------------------------------------
def jumpstart(realm: str, worker: str, aws_session: boto3.Session = None) -> None:
    """
    Get enabled lavasched jobs aimed at this worker and dispatch them.

    This allows a worker node to automatically bootstrap its scheduling.

    :param realm:           The realm name.
    :param worker:          Name of this worker.
    :param aws_session:     A boto3 Session object. If not specified, a default
                            is created.
    """

    if not aws_session:
        aws_session = boto3.Session()

    job_table_name = f'lava.{realm}.jobs'
    dynamo = aws_session.client('dynamodb')
    paginator = dynamo.get_paginator('scan')
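    # The scan paginator transparently follows LastEvaluatedKey across pages.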

    response_iterator = paginator.paginate(
        TableName=job_table_name,
        FilterExpression='#type = :type AND #worker = :worker AND #enabled = :enabled',
        ExpressionAttributeNames={
            '#type': 'type',
            '#worker': 'worker',
            '#enabled': 'enabled',
            '#job_id': 'job_id',
        },
        ExpressionAttributeValues={
            ':type': {'S': 'lavasched'},
            ':worker': {'S': worker},
            ':enabled': {'BOOL': True},
        },
        ProjectionExpression='#job_id',
    )

    jumpstart_delay = config('JUMPSTART_DELAY', duration_to_seconds)
    for response in response_iterator:
        for item in response['Items']:
            try:
                job_id = item['job_id']['S']
            except KeyError:
                LOG.warning('Jumpstart: Bad job spec: %s', item, extra={'event_type': 'worker'})
                continue

            LOG.debug(
                'Jumpstart: Dispatching %s to %s', job_id, worker, extra={'event_type': 'worker'}
            )

            try:
                dispatch(realm, job_id, worker, delay=jumpstart_delay, aws_session=aws_session)
            except LavaError as e:
                LOG.error(
                    'Jumpstart: Cannot dispatch %s: %s', job_id, e, extra={'event_type': 'worker'}
                )
                continue
            else:
                LOG.info(
                    'Jumpstart: Dispatched %s to %s', job_id, worker, extra={'event_type': 'worker'}
                )


# ------------------------------------------------------------------------------
def process_cli_args() -> argparse.Namespace:
    """
    Process the command line arguments.

    :return:    The args namespace.
    """

    argp = argparse.ArgumentParser(prog=PROG, description='Lava job worker.')

    # ------------------------------
    # General options

    argp.add_argument(
        '-b',
        '--heartbeat',
        action='store',
        type=int,
        default=0,
        help=(
            'Emit a heartbeat log message every this many seconds. A value of 0 (the default)'
            f' disables heartbeats. If specified, a minimum of {HEARTBEAT_MIN} seconds is imposed.'
        ),
    )

    argp.add_argument('--profile', action='store', help='As for AWS CLI.')

    argp.add_argument(
        '-s',
        '--sleep',
        action='store',
        type=int,
        default=SQS_POLL_SLEEP,
        help=(
            'Sleep for this many seconds between SQS poll attempts when no messages are available.'
            f' Note that the --wait time is in addition to this sleep time. Default {SQS_POLL_SLEEP}.'
        ),
    )

    argp.add_argument(
        '--retries',
        action='store',
        type=int,
        default=RETRIES,
        help=(
            'Maximum number of retry attempts to process a job event. If an event is not'
            ' processed after this many retries, it is discarded. Note that it is possible'
            ' for an event to be processed more than once. If set to a negative number,'
            f' retry limiting is disabled. Default {RETRIES}.'
        ),
    )

    argp.add_argument(
        '--root',
        action='store_true',
        help=(
            'By default, the lava worker cannot be run as root. It does not need'
            ' root privileges and it is not safe to run with them. This option allows'
            ' that restriction to be overridden. It is not a good idea.'
        ),
    )

    argp.add_argument(
        '-t',
        '--threads',
        action='store',
        type=int,
        default=THREADS,
        help=f'Run this many threads. Default {THREADS}.',
    )

    argp.add_argument('-v', '--version', action='version', version=__version__)

    argp.add_argument(
        '--wait',
        action='store',
        type=int,
        default=LONG_POLL_DEFAULT,
        help=(
            f'Wait this many seconds for the SQS queue to provide messages (long polling).'
            f' Must be in the range 0 .. {LONG_POLL_MAX}. Default is {LONG_POLL_DEFAULT} seconds.'
        ),
    )

    # ------------------------------
    # run mode options
    run_mode_group = argp.add_argument_group(
        'run mode arguments',
        description='If none of the following are specified, run in the foreground.',
    )
    run_mode_args = run_mode_group.add_mutually_exclusive_group()

    run_mode_args.add_argument(
        '--batch', action='store_true', help='Run a single batch and exit when queue is empty.'
    )

    run_mode_args.add_argument('-d', '--daemon', action='store_true', help='Run as a daemon.')

    # ------------------------------
    # Job options

    jobp = argp.add_argument_group('job options')

    jobp.add_argument(
        '--dev',
        action='store_true',
        help='Run in developer mode. This is passed through to the job'
        ' handlers. It\'s up to them to decide what they do with this.',
    )

    jobp.add_argument(
        '-j',
        '--jump-start',
        dest='jumpstart',
        action='store_true',
        help='Jump start the scheduler by running all lavasched jobs aimed at this worker.',
    )

    jobp.add_argument(
        '-q',
        '--queue',
        action='store',
        help='AWS SQS queue name. If not specified, the queue name'
        ' is derived from the realm and worker name.',
    )

    jobp.add_argument('-r', '--realm', required=True, action='store', help='Lava realm name.')

    jobp.add_argument(
        '-w',
        '--worker',
        required=True,
        action='store',
        help='Lava worker name. The worker must be a member of the specified realm.',
    )

    # ------------------------------
    # Logging options

    logp = argp.add_argument_group('logging arguments')
    logp.add_argument(
        '-c',
        '--no-colour',
        '--no-color',
        dest='colour',
        action='store_false',
        default=True,
        help='Don\'t use colour in information messages.',
    )

    logp.add_argument(
        '-l',
        '--level',
        metavar='LEVEL',
        default=LOGLEVEL,
        help=(
            'Emit messages of a given severity level or above. The standard logging level names'
            ' are available but debug, info, warning and error are most useful.'
            f' The default is {LOGLEVEL}.'
        ),
    )

    logp.add_argument(
        '--log-json',
        action='store_true',
        help=(
            'Log messages in JSON format. This is particularly useful when log'
            ' messages end up in CloudWatch logs as it simplifies searching.'
        ),
    )

    logp.add_argument(
        '--log',
        action='store',
        help='Log to the specified target. This can be either a file'
        ' name or a syslog facility with an @ prefix (e.g. @local0).',
    )

    logp.add_argument(
        '--tag',
        action='store',
        default=PROG,
        help=f'Tag log entries with the specified value. The default is {PROG}.',
    )

    args = argp.parse_args()

    if not 0 <= args.wait <= LONG_POLL_MAX:
        argp.error(f'Argument --wait: value must be in the range 0 .. {LONG_POLL_MAX}')

    if args.heartbeat and args.heartbeat < HEARTBEAT_MIN:
        argp.error(f'Argument -b/--heartbeat: value must be 0 or at least {HEARTBEAT_MIN}')

    return args


# ------------------------------------------------------------------------------
def attempt(
    action: Callable,
    description: str | None = None,
    retries: int = 0,
    backoff: float = 1.0,
    args: list[Any] | None = None,
    kwargs: dict[str, Any] | None = None,
) -> Any:
    """
    Attempt to perform some action and retry on failure.

    :param action:      A callable action.
    :param description: A description of the action.
    :param retries:     Number of retries before giving up.
    :param backoff:     Sleep this many more seconds between each attempt.
    :param args:        Positional arguments for the action.
    :param kwargs:      Keyword arguments for the action.

    :return:            Whatever the action returns.

    :raise Exception:   If the action fails after all attempts.
    """

    if not args:
        args = []
    if not kwargs:
        kwargs = {}
    if not description:
        description = str(action)

    last_error = None
    for n in range(retries + 1):
        # Linear backoff: attempt n sleeps n * backoff seconds first, so the
        # first attempt runs immediately.
        time.sleep(n * backoff)
        try:
            return action(*args, **kwargs)
        except Exception as e:
            last_error = e
            LOG.warning(f'{description}: {e}', extra={'event_type': 'worker'})
            continue

    raise Exception(f'{description}: Too many failed attempts - abort') from last_error


# ..............................................................................
# region signal handlers
# ..............................................................................


# ------------------------------------------------------------------------------
# noinspection PyUnusedLocal
def terminate(signum: int, stack_frame: Any) -> None:
    """
    Handle an interrupt signal to exit immediately.

    This at least allows the atexit() tasks to run.

    :param signum:      Signal number.
    :param stack_frame: Stack frame.
    """

    LOG.info(f'Signal {signum} -- forced shutdown', extra={'event_type': 'worker'})
    signal.signal(signum, signal.SIG_DFL)
    sys.exit(143)


# ------------------------------------------------------------------------------
# noinspection PyUnusedLocal
def interrupt(signum: int, stack_frame: Any) -> None:
    """
    Handle an interrupt signal to cause a soft exit.

    :param signum:      Signal number.
    :param stack_frame: Stack frame.

    """

    global interrupted

    LOG.info(f'Signal {signum} -- preparing for shutdown', extra={'event_type': 'worker'})

    interrupted = signum
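    # Escalate: from here on, another soft-exit signal forces an immediate hard exit.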
    for sig in SOFT_EXIT_SIGNALS:
        signal.signal(sig, terminate)


# ------------------------------------------------------------------------------
# noinspection PyUnusedLocal
def worker_internal_info(signum: int, stack_frame: Any) -> None:
    """
    Log some worker internal information.

    :param signum:      Signal number.
    :param stack_frame: Stack frame.

    """

    deferred_tasks = Defer.on_event(DEFER_ON_EXIT)
    LOG.info(
        'Current thread status: %s', ThreadMonitor().thread_status, extra={'event_type': 'worker'}
    )
    LOG.info('Deferred event count: %s', len(deferred_tasks.tasks), extra={'event_type': 'worker'})
    for task_id, task in deferred_tasks.tasks.items():
        LOG.info('Task %d: %s', task_id, task, extra={'event_type': 'worker'})

    cfg = (f'{k}={config(k)}' for k in sorted(__config__))
    LOG.info('Worker config: %s', ' '.join(cfg), extra={'event_type': 'worker'})


# ..............................................................................
# endregion signal handlers
# ..............................................................................


# ------------------------------------------------------------------------------
def main() -> int:
    """
    Show time.

    :return:    status
    """

    setup_logging(LOGLEVEL, name=LOGNAME, prefix=PROG)
    args = process_cli_args()

    if os.geteuid() == 0 and not args.root:
        raise Exception('You should not run the lava worker as root')

    if args.daemon:
        daemonify(pidfile=(args.realm + '.' + args.worker).replace('/', '-') + '.pid')

    global pid
    pid = os.getpid()

    log_formatter = None
    if args.log_json:
        log_formatter = JsonFormatter(
            fields=JSON_LOG_FIELDS,
            extra={
                'event_source': PROG,
                'realm': args.realm,
                'worker': args.worker,
                'host': platform.node(),
                'tag': args.tag,
            },
            # Important for rsyslog to have syslogtag prefix or it wrecks the JSON.
            tag=args.tag if args.tag and args.log and args.log.startswith('@') else None,
        )

    setup_logging(
        args.level,
        name=LOGNAME,
        target=args.log,
        colour=args.colour,
        prefix=args.tag,
        formatter=log_formatter,
    )

    # Setup for final cleanup
    atexit.register(Defer.on_event(DEFER_ON_EXIT).run, logger=LOG)

    aws_session = boto3.Session(profile_name=args.profile)
    sqs_queue_name = args.queue or '-'.join(['lava', args.realm, args.worker])
    sqs_queue = aws_session.resource('sqs').get_queue_by_name(QueueName=sqs_queue_name)
    _ = ThreadMonitor()  # Create our singleton thread monitor

    realm_info = config_load(args.realm, aws_session=aws_session)

    # Create the temp directory for the worker. The put_metrics thread needs it.
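    # The PID suffix keeps concurrent workers for the same realm/worker apart.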
    tmpdir = (Path(config('TMPDIR')) / args.realm / args.worker).with_suffix(f'.{os.getpid()}')
    tmpdir.mkdir(exist_ok=True, parents=True)
    Defer.on_event(DEFER_ON_EXIT).add(
        Task(f'Remove {tmpdir}', rmtree, args=[tmpdir], kwargs={'ignore_errors': True})
    )
    # Add a little sleep before we blow away the tmpdir tree on exit
    Defer.on_event(DEFER_ON_EXIT).add(Task('Deep breath', time.sleep, args=[1]))

    # Whether to check for leftover job tmp dirs (possible zombie threads) on exit.
    check_for_zombies_on_exit = config('CHECK_FOR_ZOMBIES', str2bool)

    # ----------------------------------------
    # Start the event logger thread.
    # noinspection PyTypeChecker
    EventRecorder(
        worker=args.worker, profile=args.profile, realm_info=realm_info, name='event', daemon=True
    ).start()

    # ----------------------------------------
    # Create an internal job queue and start worker
    # threads. The main thread pulls jobs off the SQS queue and puts them on the
    # internal queue for the worker threads.

    job_queue = LavaQueue(maxsize=args.threads)

    for t in range(args.threads):
        # noinspection PyTypeChecker
        JobRunner(
            worker=args.worker,
            job_queue=job_queue,
            profile=args.profile,
            realm_info=realm_info,
            tmpdir=tmpdir,
            dev_mode=args.dev,
            name=f'{WORKER_THREAD_NAME}{t:02}',
            daemon=True,
        ).start()

    # ----------------------------------------
    # Start the heartbeat thread
    if args.heartbeat:
        hb_thread = Thread(
            target=heartbeat,
            name='heartbeat',
            args=(args.realm, args.worker, job_queue, sqs_queue, args.heartbeat, tmpdir),
        )
        hb_thread.daemon = True
        hb_thread.start()

    # ----------------------------------------
    # Start the put_metrics thread
    if config('CW_METRICS_WORKER', str2bool):
        metrics_thread = Thread(
            target=put_metrics,
            name='metrics',
            args=(
                args.realm,
                args.worker,
                tmpdir,
            ),
            kwargs={
                'period': config('CW_METRICS_PERIOD', duration_to_seconds),
                'aws_session': aws_session,
            },
        )
        metrics_thread.daemon = True
        metrics_thread.start()

    # ----------------------------------------
    if args.jumpstart:
        jumpstart(args.realm, args.worker, aws_session=aws_session)

    # ----------------------------------------
    # Signal handlers
    for signum in SOFT_EXIT_SIGNALS:
        # Catch some signals for a soft exit
        signal.signal(signum, interrupt)
    for signum in HARD_EXIT_SIGNALS:
        # It's all over red rover
        signal.signal(signum, terminate)

    signal.signal(signal.SIGUSR1, worker_internal_info)

    # ----------------------------------------
    # Poll for messages on the SQS queue and place them on the internal worker queue
    LOG.info('Worker ready ...', extra={'event_type': 'worker'})
    sleep_time = max(0, args.sleep)
    while not interrupted:
        if job_queue.full():
            # Don't retrieve messages from SQS if our internal worker queue is full
            time.sleep(BUSY_SLEEP)
            continue

        LOG.debug('Polling SQS')
        job_messages = attempt(
            sqs_queue.receive_messages,
            backoff=AWS_API_BACKOFF,
            retries=AWS_API_RETRIES,
            description='sqs_queue.receive_messages',
            kwargs={
                'AttributeNames': ['All'],
                'MessageAttributeNames': ['All'],
                'WaitTimeSeconds': args.wait,
                'MaxNumberOfMessages': 1,
            },
        )

        if job_messages:
            for job in job_messages:
                attempt_number = int(job.attributes['ApproximateReceiveCount'])
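                # SQS sets ApproximateReceiveCount to 1 on first delivery, so a
                # count above retries + 1 means the initial attempt plus all
                # allowed retries have already been used.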

                if args.retries >= 0 and attempt_number > args.retries + 1:
                    LOG.warning(
                        f'Job dispatch message {job.message_id}:'
                        f' Retry limit exceeded - deleted on cycle {attempt_number}',
                        extra={'event_type': 'worker'},
                    )
                    job.delete()
                else:
                    job_queue.put(job)
        elif args.batch:
            LOG.info('No more messages in batch -- finishing', extra={'event_type': 'worker'})
            break
        elif not interrupted and sleep_time:
            LOG.debug('SQS poll sleep %d', sleep_time)
            time.sleep(sleep_time)

    LOG.info(
        'Waiting for worker threads to complete their work before finishing',
        extra={'event_type': 'worker'},
    )

    # ----------------------------------------
    # HACK WARNING
    #
    # It's possible in rare circumstances for threads to go into a zombie state
    # for an extended period.  (Yes, this should be fixed.) The consequence of
    # this is that the worker can hang when shutting down waiting for the zombie
    # to die. This can be an issue in production operation, as it then relies on
    # the auto scaler to timeout and kill the worker (which it will do just
    # fine). So, we have the option of checking if there are any job tmp
    # directories present and, if not, assume the worker can exit.

    # But there is a potential race condition here. To minimise the risk, the
    # directory check is performed multiple times before concluding the job
    # threads are idle and the worker can exit.
    #
    # This process can be disabled by setting the lava config setting
    # `CHECK_FOR_ZOMBIES` to False.
    #
    # Why, you ask in your relentless pursuit of truth, would the thread be inert
    # and yet there is no job tmp directory? For further study. It's certainly
    # not ideal but it is the end-game at this point. And it's still better
    # than that "Avengers: Endgame" debacle.

    zombie_strikes = ZOMBIE_STRIKES
    while True:
        try:
            LOG.debug('About to join on the job queue')
            job_queue.join(WORKER_QUEUE_JOIN_WAIT)
            LOG.info('Worker threads completed normally', extra={'event_type': 'worker'})
        except LavaQueueTimeoutError:
            LOG.info('Still waiting for worker threads to finish', extra={'event_type': 'worker'})
            if check_for_zombies_on_exit:
                LOG.debug('Checking for job working dirs -- %d strikes left', zombie_strikes)
                if dirs := [d for d in tmpdir.iterdir() if d.is_dir()]:
                    LOG.debug('Found %d job dirs in %s -- cannot exit yet', len(dirs), tmpdir)
                    zombie_strikes = ZOMBIE_STRIKES  # Start again
                    continue
                if not zombie_strikes:
                    LOG.info(
                        'No job working dirs found -- Possible zombie threads -- exiting',
                        extra={'event_type': 'worker'},
                    )
                    break
                zombie_strikes -= 1
                LOG.debug('No job working dirs found -- %d strikes left', zombie_strikes)
        else:
            # Threads finished normally
            break

    # Allow the daemon threads to finish before destroying everything
    time.sleep(1)
    LOG.info('Finishing', extra={'event_type': 'worker'})
    return 0


# ..............................................................................
# region threadworkers
# ..............................................................................


# ------------------------------------------------------------------------------
class JobRunner(Thread):
    """
    Thread worker to process job dispatch messages extracted from SQS.

    If the messages are processed successfully, or can never be processed
    successfully, they are removed from the SQS queue. If they raise a
    EventRetryException, they are not removed from the queue so they can be
    retried once the SQS message visibility timeout expires.

    Message body looks like this...

    .. code:: python

        {
            "realm": "...",
            "worker": "...",
            "job_id": <JOB-ID>,
            "run_id": <UUID>,
            "parameters": {
                ... optional ...
            },
            "globals": {
                ... optional ...
            }
        }

    :param worker:      Name of this lava worker.
    :param job_queue:   An internal (Python) queue containing messages extracted
                        from SQS by the main thread. Each one represents a job
                        that needs to be processed.
    :param profile:     AWS profile name (for credentials selection)
    :param realm_info:  A dictionary of realm wide parameters.
    :param tmpdir:      Worker temp directory. Job temp directories get created
                        in here. This must already exist.
    :param dev_mode:    Passed to the job handlers. It's up to them to decide
                        what they do with this. Default False.

    """

    # --------------------------------------------------------------------------
    def __init__(
        self,
        worker: str,
        job_queue: Queue,
        profile: str,
        realm_info: dict[str, Any],
        tmpdir: str | Path,
        dev_mode: bool = False,
        *args,
        **kwargs,
    ):
        """Create a job runner worker."""

        super().__init__(*args, **kwargs)

        self.realm = realm_info['realm']
        self.realm_info = realm_info
        self.worker = worker
        self.job_queue = job_queue
        self._dispatch_msg = None
        self.tmpdir = tmpdir
        self.dev_mode = dev_mode

        # Boto sessions are not thread safe so need one per thread.
        self.aws_session = boto3.Session(profile_name=profile)

        job_table_name = 'lava.' + self.realm + '.jobs'
        try:
            self.job_table = self.aws_session.resource('dynamodb').Table(job_table_name)
        except Exception as e:
            raise Exception(f'Cannot get DynamoDB table {job_table_name} - {e}')

    # --------------------------------------------------------------------------
    def get_next_job(self) -> dict[str, Any]:
        """
        Fetch the next job and validate it. Call self.job_done() when complete.

        :return:        The validated job spec for the next job.

        :raise LavaError: On error
        """

        self._dispatch_msg = self.job_queue.get()
        LOG.debug('Job: %s', self._dispatch_msg.body)

        # ----------------------------------------
        # Extract the job dispatch request
        try:
            job_dispatch = json.loads(self._dispatch_msg.body)  # type: dict
        except Exception as e:
            # job_done() clears _dispatch_msg, so capture the body before calling it.
            body = self._dispatch_msg.body
            self.job_done()
            raise LavaError(f'Malformed dispatch message: {e}: {body}')

        # ----------------------------------------
        # Validate the job dispatch request.
        try:
            dict_check(
                job_dispatch,
                required=JOB_DISPATCH_REQUIRED_FIELDS,
                optional=JOB_DISPATCH_OPTIONAL_FIELDS,
            )
        except ValueError as e:
            self.job_done()
            raise LavaError(f'Malformed dispatch message: {e}')

        try:
            job_dispatch['ts_dispatch'] = dateutil.parser.parse(job_dispatch['ts_dispatch'])
        except ValueError as e:
            self.job_done()
            raise LavaError(f'Bad timestamp in dispatch message: {e}')

        job_id = job_dispatch['job_id']
        run_id = job_dispatch['run_id']

        # ----------------------------------------
        # Make sure realm name matches
        if job_dispatch['realm'] != self.realm:
            self.job_done()
            raise LavaError(
                f'Bad dispatch for job {job_id} ({run_id}): Realm mismatch:'
                f' Expected {self.realm} but got {job_dispatch["realm"]}'
            )

        # ----------------------------------------
        # Make sure worker name matches
        if job_dispatch['worker'] != self.worker:
            self.job_done()
            raise LavaError(
                f'Bad dispatch for job {job_id} ({run_id}): Worker mismatch:'
                f' Expected {self.worker} but got {job_dispatch["worker"]}'
            )

        # ----------------------------------------
        # Get the job info from DynamoDB
        try:
            job_spec = get_job_spec(job_id, self.job_table)
        except Exception as e:
            self.job_done()
            raise LavaError(f'Job {job_id} ({run_id}): {e}')

        # ----------------------------------------
        # Make sure job worker is current worker
        if job_spec['worker'] != self.worker:
            self.job_done()
            raise LavaError(
                f'Job {job_id} ({run_id}): Worker mismatch:'
                f' Expected {self.worker} but got {job_spec["worker"]}'
            )

        # ----------------------------------------
        # Make sure the dispatch message hasn't exceeded the max_tries for the job
        receive_count = int(self._dispatch_msg.attributes['ApproximateReceiveCount'])
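        # A max_tries of zero or less disables the limit; otherwise a receive
        # count above max_tries means the allowed attempts are already spent.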
        if 0 < job_spec['max_tries'] < receive_count:
            self.job_done()
            raise LavaError(f'Job {job_id} ({run_id}): max_tries exceeded')

        # ----------------------------------------
        # Augment job spec and return it.
        job_spec['parameters'].update(job_dispatch.get('parameters', {}))
        job_spec['run_id'] = run_id
        job_spec['ts_dispatch'] = job_dispatch['ts_dispatch']
        job_spec['realm'] = self.realm
        job_spec['globals'].update(job_dispatch.get('globals', {}))

        # Convert lava's private global timestamps back into datetime. These
        # could have been populated by an upstream dispatch job.

        if 'lava' in job_spec['globals']:
            lava_globals = job_spec['globals']['lava']
            for ts in (
                'master_start',
                'master_ustart',
                'parent_start',
                'parent_ustart',
            ):
                if ts in lava_globals:
                    try:
                        lava_globals[ts] = dateutil.parser.isoparse(lava_globals[ts])
                    except ValueError:
                        self.job_done()
                        raise LavaError(f'{ts}: Bad timestamp: {lava_globals[ts]}')

        return job_spec

    # --------------------------------------------------------------------------
    def job_done(self) -> None:
        """Mark the current job as done."""

        if not self._dispatch_msg:
            LOG.critical(
                'Internal error: Attempted to complete non-existent job',
                extra={'event_type': 'worker'},
            )
            raise Exception('Internal error: Attempted to complete non-existent job')

        self._dispatch_msg.delete()
        self.job_queue.task_done()
        self._dispatch_msg = None

    # --------------------------------------------------------------------------
    def run(self) -> None:
        """Loop on jobs from the job queue."""

        LOG.debug('Starting')
        ThreadMonitor().register_thread()

        while True:
            try:
                job_spec = self.get_next_job()
            except Exception as e:
                LOG.error(str(e), extra={'event_type': 'worker'})
                continue

            job_id = job_spec['job_id']
            run_id = job_spec['run_id']

            # noinspection PyBroadException
            try:
                run_job(
                    job_spec,
                    self.realm_info,
                    self.tmpdir,
                    cleanup=True,
                    dev_mode=self.dev_mode,
                    aws_session=self.aws_session,
                )
            except Exception as e:
                LOG.error(
                    'Job %s (%s): %s',
                    job_id,
                    run_id,
                    e,
                    extra={'event_type': 'job', 'job_id': job_id, 'run_id': run_id},
                )
            finally:
                self.job_done()


# ------------------------------------------------------------------------------
def heartbeat(
    realm: str, worker: str, job_queue: Queue, sqs_queue, hb_period: int, tmpdir: Path
) -> NoReturn:
    """
    Thread worker to issue heartbeat messages.

    :param realm:       Name of lava realm for this worker.
    :param worker:      Name of this lava worker.
    :param job_queue:   An internal (Python) queue containing messages extracted
                        from SQS by the main thread. Each one represents a job
                        that needs to be processed.
    :param sqs_queue:   SQS queue from which events are being loaded.
    :param hb_period:   How often heartbeat messages are issued.
    :param tmpdir:      Worker temp directory.

    :return:            Never.
    """

    LOG.debug('Starting')
    monitor = ThreadMonitor()
    monitor.register_thread()
    if hb_period <= 0:
        raise ValueError(f'Bad heartbeat period: {hb_period}')

    wake_time = time.monotonic()

    sqs_q_arn = sqs_queue.attributes['QueueArn']
    deferred_tasks = Defer.on_event(DEFER_ON_EXIT)
    heartbeat_file = tmpdir / config('HEARTBEAT_FILE') if config('HEARTBEAT_FILE') else None

    while True:
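        # Fixed-cadence scheduling: advance an absolute wake time rather than
        # sleeping a fixed interval, so per-iteration latency doesn't drift the beat.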
        wake_time += hb_period
        try:
            sqs_queue.load()
        except Exception as e:
            LOG.warning(
                f'Could not get attributes on {sqs_q_arn}: {e}', extra={'event_type': 'worker'}
            )
            sqs_q_len = -1
            sqs_q_nvis = -1
        else:
            sqs_q_len = int(sqs_queue.attributes['ApproximateNumberOfMessages'])
            sqs_q_nvis = int(sqs_queue.attributes['ApproximateNumberOfMessagesNotVisible'])

        # This is basically a keep-alive on our tmp dir so some O/S cleanup
        # process doesn't blow it away due to inactivity. If, for some reason
        # the directory is blown away, this will cause the heartbeat thread
        # to die which will cause a heartbeat alarm. This is intended as
        # the worker is stuffed at this point.

        if heartbeat_file:
            heartbeat_file.touch()

        LOG.info(
            # Legacy format in the message.
            f'heartbeat realm={realm} worker={worker}',
            extra={
                'event_type': 'heartbeat',
                'sqs': {'messages': sqs_q_len, 'notvisible': sqs_q_nvis},
                'internal': {'qlen': job_queue.qsize()},
                'threads': monitor.thread_status,
                'deferred_tasks': len(deferred_tasks.tasks),
            },
        )

        try:
            time.sleep(wake_time - time.monotonic())
        except ValueError as e:
            LOG.warning(str(e), extra={'event_type': 'worker'})


# ------------------------------------------------------------------------------
def put_metrics(
    realm: str,
    worker: str,
    tmpdir: str | Path,
    period: int = 60,
    aws_session: boto3.Session = None,
) -> NoReturn:
    """
    Thread worker to report internal system stats to CloudWatch metrics.

    :param realm:       Name of lava realm for this worker.
    :param worker:      Name of this lava worker.
    :param tmpdir:      The working directory used by the worker.
    :param period:      Reporting period in seconds. Default 60 seconds.
    :param aws_session: A boto3 Session(). If not specified, a default session
                        is created.

    """

    LOG.debug('Starting')
    monitor = ThreadMonitor()
    monitor.register_thread()
    if period <= 0:
        raise ValueError(f'Bad status period: {period}')

    # Set up for disk usage metric on the local docker vol
    try:
        docker_dir = docker.client.from_env().info()['DockerRootDir']
    except Exception as e:
        LOG.warning(f'Cannot determine docker volume: {e}', extra={'event_type': 'worker'})
        docker_dir = None

    if docker_dir and not os.path.exists(docker_dir):
        LOG.warning(
            f'Docker volume {docker_dir} does not exist - may not be local',
            extra={'event_type': 'worker'},
        )
        docker_dir = None

    if not aws_session:
        aws_session = boto3.Session()

    cw = aws_session.client('cloudwatch')
    namespace = config('CW_NAMESPACE')

    # noinspection PyBroadException
    try:
        node_id = ec2_instance_id()
        LOG.debug('Looks like we\'re in ec2')
    except Exception:
        LOG.debug('Looks like we\'re not in ec2')
        node_id = socket.gethostname()
    LOG.info(f'Node ID is {node_id}', extra={'event_type': 'worker'})

    dimensions = [
        {'Name': 'Realm', 'Value': realm},
        {'Name': 'Worker', 'Value': worker},
        {'Name': 'Instance', 'Value': node_id},
    ]

    time.sleep(15)  # Offset from the heartbeat

    wake_time = time.monotonic()
    vm_total = psutil.virtual_memory().total

    while True:
        wake_time += period
        r = resource.getrusage(resource.RUSAGE_SELF)

        try:
            ts = now_tz()
            workers_alive, workers_dead = monitor.threadcount(WORKER_THREAD_NAME + '*')
            metric_data = [
                {
                    'MetricName': 'WorkerThreadsAlive',
                    'Value': workers_alive,
                    'Unit': 'Count',
                    'Dimensions': dimensions,
                    'Timestamp': ts,
                },
                {
                    'MetricName': 'WorkerThreadsDead',
                    'Value': workers_dead,
                    'Unit': 'Count',
                    'Dimensions': dimensions,
                    'Timestamp': ts,
                },
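                # Note: getrusage() reports ru_maxrss in kilobytes on Linux but
                # in bytes on macOS, so the 'Bytes' unit below is platform-dependent.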
                {
                    'MetricName': 'MaxRss',
                    'Value': r.ru_maxrss,
                    'Unit': 'Bytes',
                    'Dimensions': dimensions,
                    'Timestamp': ts,
                },
                # This one is legacy -- will be removed
                {
                    'MetricName': 'PercentDiskUsed',
                    'Value': psutil.disk_usage(tmpdir).percent,
                    'Unit': 'Percent',
                    'Dimensions': dimensions,
                    'Timestamp': ts,
                },
                # Replacement for old PercentDiskUsed
                {
                    'MetricName': 'PercentTmpDiskUsed',
                    'Value': psutil.disk_usage(tmpdir).percent,
                    'Unit': 'Percent',
                    'Dimensions': dimensions,
                    'Timestamp': ts,
                },
                {
                    'MetricName': 'PercentMemUsed',
                    'Value': (vm_total - psutil.virtual_memory().available) / vm_total * 100.0,
                    'Unit': 'Percent',
                    'Dimensions': dimensions,
                    'Timestamp': ts,
                },
                {
                    'MetricName': 'PercentSwapUsed',
                    'Value': psutil.swap_memory().percent,
                    'Unit': 'Percent',
                    'Dimensions': dimensions,
                    'Timestamp': ts,
                },
            ]

            if docker_dir:
                metric_data.append(
                    {
                        'MetricName': 'PercentDockerDiskUsed',
                        'Value': psutil.disk_usage(docker_dir).percent,
                        'Unit': 'Percent',
                        'Dimensions': dimensions,
                        'Timestamp': ts,
                    },
                )
            cw.put_metric_data(Namespace=namespace, MetricData=metric_data)
            LOG.debug('Pushed node metrics to CloudWatch')
        except Exception as e:
            LOG.warning(
                f'Could not put CloudWatch metric data: {e}', extra={'event_type': 'worker'}
            )

        try:
            time.sleep(wake_time - time.monotonic())
        except ValueError as e:
            LOG.warning(str(e), extra={'event_type': 'worker'})


# ..............................................................................
# endregion threadworkers
# ..............................................................................


# ------------------------------------------------------------------------------
if __name__ == '__main__':
    # Uncomment for debugging to get full tracebacks instead of one-line errors.
    # sys.exit(main())  # noqa: ERA001
    try:
        sys.exit(main())
    except Exception as ex:
        LOG.error('%s', ex, extra={'event_type': 'worker'})
        sys.exit(1)
    except KeyboardInterrupt:
        LOG.warning('Interrupt', extra={'event_type': 'worker'})
        sys.exit(2)
