#!/usr/bin/env python3

"""Generate a task message for a worker queue."""

import argparse
import logging
import os
import platform
import sys

import boto3

from lava.config import config
from lava.lavacore import dispatch
from lava.lib.argparse import StoreNameValuePair
from lava.lib.datetime import duration_to_seconds
from lava.lib.logging import JsonFormatter, setup_logging
from lava.lib.misc import dict_expand_keys
from lava.version import __version__

__author__ = 'Murray Andrews'

# Program name: the basename of the invoked script with any extension removed.
PROG = os.path.basename(os.path.splitext(sys.argv[0])[0])

# Logger used throughout this script.
LOGNAME = 'lava'  # Set to None to include boto which uses root logger.
LOGLEVEL = 'info'  # Initial logging level, also the default for --level.
LOG = logging.getLogger(LOGNAME)

# Fields emitted when JSON logging is enabled. Keys are the JSON field names;
# values look like logging.LogRecord attribute names (presumably the source of
# each field -- confirm against lava.lib.logging.JsonFormatter).
JSON_LOG_FIELDS = dict(
    timestamp='asctime',
    level='levelname',
    message='message',
    pid='process',
)


# ------------------------------------------------------------------------------
def process_cli_args() -> argparse.Namespace:
    """
    Process the command line arguments.

    In addition to parsing, several options are validated and normalised
    in place:

    -   A realm must be available, either from -r/--realm or from the
        LAVA_REALM environment variable.
    -   Dotted -p/--param and -g/--global names (e.g. x.y) are expanded into
        nested dicts via dict_expand_keys.
    -   -d/--delay is converted to a whole number of seconds and checked
        against the SQS_MAX_DELAY_MINS config value.

    Any invalid argument causes argparse to print a usage message and exit
    rather than return.

    :return:    The args namespace.
    """

    # Looked up once and used for both the help text and the validation.
    max_delay_mins = config('SQS_MAX_DELAY_MINS', int)

    argp = argparse.ArgumentParser(prog=PROG, description='Lava job dispatcher.')

    # ------------------------------
    # General options

    argp.add_argument('--profile', action='store', help='As for AWS CLI.')

    argp.add_argument('-v', '--version', action='version', version=__version__)

    argp.add_argument(
        '--check-dispatch',
        action='store_true',
        help=(
            'If specified, check for the existence of a dispatch suppression'
            f' file "{config("NO_DISPATCH")}". If the file is present, all dispatches'
            ' are suppressed. This is typically only used for scheduled dispatches'
            ' when a dispatcher node is in the process of shutting down.'
        ),
    )

    # ------------------------------
    # Dispatch control options

    disp = argp.add_argument_group('dispatch control options')

    disp.add_argument(
        '-d',
        '--delay',
        action='store',
        default=0,
        help=(
            'Delay dispatch by the specified duration. Default is 0.'
            f' Maximum is {max_delay_mins} minutes.'
        ),
    )

    disp.add_argument(
        '-q',
        '--queue',
        action='store',
        help=(
            'AWS SQS queue name. If not specified, the queue name'
            ' is derived from the realm and worker name.'
        ),
    )

    disp.add_argument(
        '-r',
        '--realm',
        default=os.environ.get('LAVA_REALM'),
        action='store',
        help=(
            'Lava realm name. Defaults to the value of the LAVA_REALM environment'
            ' variable. A value must be specified by one of these mechanisms.'
        ),
    )

    disp.add_argument(
        '-w',
        '--worker',
        action='store',
        help=(
            'Lava worker name. The worker must be a member of the specified realm. If specified,'
            ' the worker name must match the value in the job specification. If not specified,'
            ' the correct value will be looked up in the jobs table.'
        ),
    )

    # ------------------------------
    # Job options

    jobp = argp.add_argument_group('job options')

    jobp.add_argument(
        '-g',
        '--global',
        dest='globals',
        action=StoreNameValuePair,
        metavar='name=VALUE',
        help=(
            'Additional global attribute to include in the job dispatch event. This option can be'
            ' used multiple times. If global names contain dots, they will be converted into a'
            ' hierarchy using the dots as level separators.'
        ),
    )

    jobp.add_argument(
        '-p',
        '--param',
        action=StoreNameValuePair,
        metavar='name=VALUE',
        help=(
            'Additional parameter to include in the job dispatch event. This option can be used'
            ' multiple times. If parameter names contain dots, they will be converted into a'
            ' hierarchy using the dots as level separators.'
        ),
    )
    jobp.add_argument(
        'job_id',
        metavar='job-id',
        nargs='+',
        help='One or more job IDs for the specified realm.',
    )

    # ------------------------------
    # Logging options

    logp = argp.add_argument_group('logging arguments')
    logp.add_argument(
        '-c',
        '--no-colour',
        '--no-color',
        dest='colour',
        action='store_false',
        default=True,
        help="Don't use colour in information messages.",
    )

    logp.add_argument(
        '-l',
        '--level',
        metavar='LEVEL',
        default=LOGLEVEL,
        help=(
            'Print messages of a given severity level or above. The standard logging level names'
            ' are available but debug, info, warning and error are most useful.'
            f' The default is {LOGLEVEL}.'
        ),
    )

    logp.add_argument(
        '--log-json',
        action='store_true',
        help=(
            'Log messages in JSON format. This is particularly useful when log'
            ' messages end up in CloudWatch logs as it simplifies searching.'
        ),
    )

    logp.add_argument(
        '--log',
        action='store',
        help=(
            'Log to the specified target. This can be either a file'
            ' name or a syslog facility with an @ prefix (e.g. @local0).'
        ),
    )

    logp.add_argument(
        '--tag',
        action='store',
        default=PROG,
        help=f'Tag log entries with the specified value. The default is {PROG}.',
    )

    # ------------------------------
    args = argp.parse_args()

    if not args.realm:
        argp.error('Lava realm must be specified')

    # Look for any params / globals with names like x.y and turn them into
    # sub-dicts x[y]. Params are processed before globals.
    for attr, opt_names in (('param', '-p, --param'), ('globals', '-g, --global')):
        values = getattr(args, attr)
        if values:
            try:
                setattr(args, attr, dict_expand_keys(values))
            except ValueError as e:
                argp.error(f'Bad {opt_names} args: {e}')

    try:
        delay = round(duration_to_seconds(args.delay))
        if not 0 <= delay <= 60 * max_delay_mins:
            raise ValueError(f'Must be between 0 and {max_delay_mins} minutes')
        args.delay = delay
    except (ValueError, OverflowError) as e:
        # round() raises OverflowError (not ValueError) for an infinite float,
        # so catch it too and report it as a usage error.
        argp.error(f'Bad -d, --delay value: {e}')

    return args


# ---------------------------------------------------------------------------------------
def main() -> int:
    """
    Dispatch every job ID named on the command line.

    Logging is configured twice. The first pass uses defaults so that problems
    during argument processing get reported. The second pass applies the
    logging options from the command line.

    :return:    Exit status. 0 if every dispatch succeeded or dispatch was
                suppressed by the inhibitor file; 1 if any dispatch failed.
    """

    # Provisional logging until the command line has been processed.
    setup_logging(LOGLEVEL, name=LOGNAME, prefix=PROG)
    args = process_cli_args()

    formatter = None
    if args.log_json:
        extra_fields = {
            'event_source': PROG,
            'realm': args.realm,
            'host': platform.node(),
            'tag': args.tag,
        }
        # Important for rsyslog to have syslogtag prefix or it wrecks the JSON,
        # so a tag is only supplied when the target is an @facility.
        to_syslog = args.log and args.log.startswith('@')
        formatter = JsonFormatter(
            fields=JSON_LOG_FIELDS,
            extra=extra_fields,
            tag=args.tag if args.tag and to_syslog else None,
        )

    setup_logging(
        args.level,
        name=LOGNAME,
        target=args.log,
        colour=args.colour,
        prefix=args.tag,
        formatter=formatter,
    )

    # Honour the dispatch inhibitor file, but only when asked to check for it.
    if args.check_dispatch:
        inhibitor = config('NO_DISPATCH')
        if os.path.exists(inhibitor):
            LOG.warning(
                f'Dispatch suppressed by presence of {inhibitor}',
                extra={'event_type': 'dispatch'},
            )
            return 0

    # ----------------------------------------
    # One shared AWS session for all dispatches.
    session = boto3.Session(profile_name=args.profile)
    status = 0

    for job_id in args.job_id:
        try:
            run_id = dispatch(
                realm=args.realm,
                job_id=job_id,
                worker=args.worker,
                params=args.param,
                globals_=args.globals,
                delay=args.delay,
                queue_name=args.queue,
                aws_session=session,
            )
        except Exception as e:
            # A failed job must not stop the remaining jobs from dispatching.
            status = 1
            LOG.error(
                f'{job_id}@{args.realm} dispatch failed - {e}',
                extra={'event_type': 'dispatch', 'job_id': job_id},
            )
            continue

        LOG.info(
            f'{job_id}@{args.realm} dispatched ({run_id})',
            extra={'event_type': 'dispatch', 'job_id': job_id, 'run_id': run_id},
        )

    return status


# ------------------------------------------------------------------------------
if __name__ == '__main__':
    # Uncomment for debugging (lets exceptions escape with a full traceback)
    # sys.exit(main())  # noqa: ERA001
    #
    # Use sys.exit() rather than the exit() builtin: exit() is injected by the
    # site module for interactive use and is missing under `python -S` or in
    # frozen executables.
    try:
        sys.exit(main())
    except Exception as ex:
        # SystemExit (from sys.exit() or argparse) is not an Exception
        # subclass, so normal exits and usage errors pass straight through.
        LOG.error('%s', ex)
        sys.exit(1)
    except KeyboardInterrupt:
        LOG.warning('Interrupt')
        sys.exit(2)
