#!/usr/bin/env python3

"""Run a bunch of health checks on a lava realm."""

from __future__ import annotations

import argparse
import logging
import os
import sys
from collections import namedtuple
from fnmatch import fnmatch
from functools import lru_cache
from typing import Any, Callable

import boto3
from dateutil.tz import UTC
from jinja2 import Template
from jinja2.exceptions import TemplateError, TemplateRuntimeError

from lava.lavacore import LOGNAME, get_realm_info, jinja_render_vars
from lava.lib.aws import dynamo_scan_table
from lava.lib.datetime import now_tz
from lava.lib.logging import setup_logging
from lava.lib.misc import TrackedMapping, format_dict_unescaped
from lava.version import __version__

__author__ = 'Murray Andrews'

# Program name, derived from the invoking script's file name without extension.
PROG = os.path.splitext(os.path.basename(sys.argv[0]))[0]
LOG = logging.getLogger(LOGNAME)
# Default logging level (also quoted in the --level help text).
LOGLEVEL = 'info'

# The table spec field to suppress checks. The contents must be either a string
# containing the name of a single check type or a list of strings to suppress
# multiple checks.
NO_CHECK = 'x-lava-nocheck'

# A registered checker: the callable that runs the check and a one-line
# description taken from the checker's docstring (see lava_check()).
LavaChecker = namedtuple('LavaChecker', field_names=('func', 'description'))
# One offending table entry reported by a checker. owner and details may be None.
BadItem = namedtuple('BadItem', field_names=('id', 'owner', 'details'))

# Registry of checkers keyed on check type (the decorated function's name).
# Populated at import time by the @lava_check decorator.
CHECKERS = {}
# Jinja template used by main() to print the bad items of one check type as a
# pipe-delimited table with a heading line. Each column is sized to its longest
# value (the Owner and Details headings are included in that calculation), the
# description has trailing full stops trimmed, and None owner/details print as
# blanks.
OUTPUT_TEMPLATE = """
{%- set _owner = items | map(attribute='owner') | reject('none') | list -%}
{%- set _ = _owner.append('Owner') -%}
{%- set owner_width = _owner | map('length') | max -%}
{%- set _details = items | map(attribute='details') | reject('none') | list -%}
{%- set _ = _details.append('Details') -%}
{%- set details_width = _details | map('length') | max -%}
{%- set id_width = items | map(attribute='id') | map('length') | max %}
# {{ check_type }}: {{ description | trim('.') }}

|{{ 'ID' | center(id_width+2) -}}
    |{{ 'Owner' | center(owner_width + 2) -}}
    |{{ 'Details' | center(details_width + 2) -}}
    |
|{{ '-' * (id_width+2) -}}
    |{{ '-' * (owner_width+2) -}}
    |{{ '-' * (details_width+2) -}}
    |
{% for item in items -%}
| {{ '{:{}}'.format(item.id, id_width+1) -}}
    | {{ '{:{}}'.format(item.owner or '', owner_width+1) -}}
    | {{ '{:{}}'.format(item.details or '', details_width+1) -}}
    |
{% endfor -%}
"""


# ------------------------------------------------------------------------------
class Sponge(dict):
    """
    Class to accept any form of access without complaint.

    Item access (``s[key]``), attribute access (``s.attr``) and calls
    (``s(...)``) all succeed. This is used to stand in for values whose
    structure is unknown when trial-rendering Jinja templates.

    Inheriting from dict is a kludge to allow JSON serialisability as sometimes
    this can occur in a Jinja template. The sponge always presents as an empty
    dict (length 0).

    Dunder (``__xxx__``) attribute lookups are not absorbed; they raise
    AttributeError as normal. Otherwise protocol probes such as the
    ``__setstate__`` lookup done by copy/pickle, or ``hasattr(x, '__html__')``,
    get a sponge back and misbehave.

    :param default_factory: If specified, this callable is used to create values.
                            If not specified, the sponge itself is used as the
                            response to every access.
    """

    # Class-level fallback so ``self._default_factory`` always resolves through
    # normal lookup, even on instances created without __init__ (e.g. by copy or
    # pickle). Without it, __getattr__ would look the attribute up via itself
    # and recurse forever.
    _default_factory = None

    # noinspection PyMissingConstructor
    def __init__(self, default_factory: Callable = None):
        """Create a sponge to absorb item / attribute access requests."""
        self._default_factory = default_factory

    def _absorb(self):
        """Return the response to any absorbed access."""
        if self._default_factory:
            return self._default_factory()
        return self

    def __getitem__(self, item):
        """Get an item."""
        return self._absorb()

    def __getattr__(self, item):
        """Get an attribute (dunder attributes are not absorbed)."""
        if item.startswith('__') and item.endswith('__'):
            raise AttributeError(item)
        return self._absorb()

    def __len__(self) -> int:
        """Get a dummy length."""
        return 0

    def __call__(self, *args, **kwargs):
        """Call the sponge and return itself."""
        return self


# ------------------------------------------------------------------------------
def lava_check(func: Callable):
    """
    Register health checkers.

    The first non-blank line of the decorated function's docstring is used to
    describe the check and the function name is the check type (label).

    Usage:

    .. code:: python

        @lava_check
        def a_func(realm, aws_session, suppression=True, check_type=None):
            ...

    The returned wrapper injects ``check_type`` (the function name) as a
    keyword argument on every call.

    :param func:    Function to decorate.

    :return:        The wrapped function.

    :raise Exception: If a checker with the same name is already registered.
    """
    from functools import wraps

    check_type = func.__name__

    @wraps(func)
    def wrapped(*args, **kwargs):
        """Call the checker with its check type injected."""
        return func(*args, check_type=check_type, **kwargs)

    if check_type in CHECKERS:
        raise Exception(f'{check_type} is already registered')

    # Take the first non-blank docstring line. This works whether the docstring
    # text starts on the same line as the opening quotes or on the next one, and
    # a missing docstring yields an empty description instead of a crash.
    doc_lines = (func.__doc__ or '').strip().splitlines()
    description = doc_lines[0].strip() if doc_lines else ''
    CHECKERS[check_type] = LavaChecker(wrapped, description)

    return wrapped


# ------------------------------------------------------------------------------
def nocheck(check_type: str, spec: dict[str, Any]) -> bool:
    """
    Determine if a check type is suppressed for the given Dynamo specification entry.

    The NO_CHECK field may be a string naming a single check type, or a
    collection of check type names: a list, or a DynamoDB string set (which
    boto3 deserialises to a Python set). Any other value is logged as a warning
    and suppresses nothing.

    :param check_type:  The check type.
    :param spec:        The specification entry (job, connection etc.)

    :return:            True if the check is suppressed.
    """

    suppress_field = spec.get(NO_CHECK)
    if not suppress_field:
        return False

    if isinstance(suppress_field, str):
        return check_type == suppress_field

    if isinstance(suppress_field, (list, tuple, set, frozenset)):
        return check_type in suppress_field

    # Anything else (number, map, ...) cannot sensibly name check types. Applying
    # `in` would raise TypeError for a number or silently match keys of a map.
    LOG.warning('Bad %s field in spec: %s', NO_CHECK, spec)
    return False


# ------------------------------------------------------------------------------
def print_checkers() -> None:
    """Print each registered check type, sorted by name, with its description."""

    # Labels are "name:" left-justified to the longest name plus the colon.
    label_width = max(map(len, CHECKERS)) + 1

    print('\nAvailable checkers are:\n')
    for name in sorted(CHECKERS):
        label = (name + ':').ljust(label_width)
        print(f'{label} {CHECKERS[name].description}')
    print()


# ------------------------------------------------------------------------------
@lru_cache(maxsize=10)
def load_table(table_name: str, key: str, aws_session: boto3.Session) -> dict[str, dict[str, Any]]:
    """
    Scan an entire DynamoDB table into memory and cache the result.

    Results are cached per (table_name, key, aws_session) argument set, so
    several checkers reading the same table trigger only one scan. The returned
    dict and its items are shared between all callers.

    Be careful with this in terms of memory use and DynamoDB full table scans.

    :param table_name:      Name of the table.
    :param key:             The table hash key.
    :param aws_session:     A boto3 Session().

    :return:                A dict of table items keyed on the table key.
    """

    LOG.debug('Loading table: %s', table_name)
    items_by_key = {}
    for item in dynamo_scan_table(table_name, aws_session):
        items_by_key[item[key]] = item
    return items_by_key


# ..............................................................................
# region job checkers
# ------------------------------------------------------------------------------


# ------------------------------------------------------------------------------
@lava_check
def jobrsu(
    realm: str, aws_session: boto3.Session, suppression=True, check_type=None
) -> list[BadItem]:
    """
    Redshift_unload jobs with insecure set to true.

    A job is flagged when its type is ``redshift_unload`` and its
    ``parameters.insecure`` value is truthy. A job whose ``parameters`` field is
    missing, null or not an object is treated as having no parameters.

    :param realm:           Lava realm.
    :param aws_session:     A boto3 Session.
    :param suppression:     If True, allow checks to be suppressed by a marker
                            in the DynamoDB table entry.
    :param check_type:      This is injected by the decorator.

    :return:                A list of BadItem entries, one per offending job.
    """

    LOG.debug('Running %s', check_type)

    jobs = load_table(f'lava.{realm}.jobs', 'job_id', aws_session)
    bad_jobs = []

    for job_id, job_spec in jobs.items():
        if suppression and nocheck(check_type, job_spec):
            LOG.debug('Check %s suppressed for %s', check_type, job_id)
            continue
        if job_spec.get('type') != 'redshift_unload':
            continue
        # parameters can be present but set to None (or some other non-object),
        # in which case .get('parameters', {}).get(...) would crash the check.
        params = job_spec.get('parameters')
        if isinstance(params, dict) and params.get('insecure'):
            bad_jobs.append(BadItem(id=job_id, owner=job_spec.get('owner'), details=None))

    return bad_jobs


# ------------------------------------------------------------------------------
# noinspection PyUnusedLocal
@lava_check
def jobmeta(
    realm: str, aws_session: boto3.Session, suppression=True, check_type=None
) -> list[BadItem]:
    """
    Job specs with missing metadata (e.g. description, owner).

    A metadata field counts as missing when it is absent or empty.

    :param realm:           Lava realm.
    :param aws_session:     A boto3 Session.
    :param suppression:     Not used. Check cannot be suppressed. Captain's call.
    :param check_type:      This is injected by the decorator.

    :return:                A list of BadItem entries, one per offending job.
    """

    LOG.debug('Running %s', check_type)

    required_fields = ('description', 'owner')
    offenders = []
    for job_id, job in load_table(f'lava.{realm}.jobs', 'job_id', aws_session).items():
        absent = [name for name in required_fields if not job.get(name)]
        if not absent:
            continue
        offenders.append(
            BadItem(id=job_id, owner=job.get('owner'), details='Missing: ' + ', '.join(absent))
        )
    return offenders


# ------------------------------------------------------------------------------
@lava_check
def jobrepo(
    realm: str, aws_session: boto3.Session, suppression=True, check_type=None
) -> list[BadItem]:
    """
    Job specs that don't appear to have an associated repo (x-lava-git-repo).

    A job is flagged when its ``x-lava-git-repo`` field is absent or empty.

    :param realm:           Lava realm.
    :param aws_session:     A boto3 Session.
    :param suppression:     If True, allow checks to be suppressed by a marker
                            in the DynamoDB table entry.
    :param check_type:      This is injected by the decorator.

    :return:                A list of BadItem entries, one per offending job.
    """

    LOG.debug('Running %s', check_type)

    offenders = []
    for job_id, job in load_table(f'lava.{realm}.jobs', 'job_id', aws_session).items():
        if suppression and nocheck(check_type, job):
            LOG.debug('Check %s suppressed for %s', check_type, job_id)
        elif not job.get('x-lava-git-repo'):
            offenders.append(BadItem(id=job_id, owner=job.get('owner'), details=None))
    return offenders


# ------------------------------------------------------------------------------
@lava_check
def joborphan(
    realm: str, aws_session: boto3.Session, suppression=True, check_type=None
) -> list[BadItem]:
    """
    Jobs with no recorded run events.

    For each job, the ``lava.<realm>.events`` table is queried (count only,
    limit 1) on the job_id key; jobs with no events are flagged. The total
    consumed capacity of those queries is logged at debug level.

    :param realm:           Lava realm.
    :param aws_session:     A boto3 Session.
    :param suppression:     If True, allow checks to be suppressed by a marker
                            in the DynamoDB table entry.
    :param check_type:      This is injected by the decorator.

    :return:                A list of BadItem entries, one per job without events.
    """

    LOG.debug('Running %s', check_type)

    jobs = load_table(f'lava.{realm}.jobs', 'job_id', aws_session)
    orphans = []
    events_table = aws_session.resource('dynamodb').Table(f'lava.{realm}.events')

    total_capacity = 0.0
    for job_id, job in jobs.items():
        if suppression and nocheck(check_type, job):
            LOG.debug('Check %s suppressed for %s', check_type, job_id)
            continue
        LOG.debug('Checking events for %s', job_id)
        query_args = dict(
            Select='COUNT',
            Limit=1,
            ReturnConsumedCapacity='TOTAL',
            KeyConditionExpression='#job_id = :job_id',
            ExpressionAttributeNames={'#job_id': 'job_id'},
            ExpressionAttributeValues={':job_id': job_id},
        )
        reply = events_table.query(**query_args)
        total_capacity += reply['ConsumedCapacity']['CapacityUnits']
        if not reply['Count']:
            orphans.append(BadItem(id=job_id, owner=job.get('owner'), details=None))

    LOG.debug('Events table capacity consumed: %.1f', total_capacity)
    return orphans


# ------------------------------------------------------------------------------
# noinspection PyUnusedLocal
@lava_check
def jobjinja(
    realm: str, aws_session: boto3.Session, suppression=True, check_type=None
) -> list[BadItem]:
    """
    Job specs with Jinja rendering issues.

    To do this we do a trial Jinja rendering of the entire object (which is not
    what the worker does but close enough). Note that this is a simulated
    approximation of what the actual worker does at runtime. The globals
    rendering component is replaced with a proxy object that tracks attempts to
    access unknown items and some other components are replaced with "sponges"
    that allow any form of attribute or item access.

    Each job spec is deep copied before it is augmented. load_table() caches
    its result, and the same item dicts are shared with the other checkers.
    Jobs whose parameters, globals or state fields are not objects are reported
    as bad items instead of aborting the whole check.

    :param realm:           Lava realm.
    :param aws_session:     A boto3 Session.
    :param suppression:     If True, allow checks to be suppressed by a marker
                            in the DynamoDB table entry.
    :param check_type:      This is injected by the decorator.

    :return:                A list of BadItem entries. A job can appear more
                            than once (e.g. a template error and undeclared
                            globals).
    """
    import copy

    LOG.debug('Running %s', check_type)

    realms_table = aws_session.resource('dynamodb').Table('lava.realms')
    realm_info = get_realm_info(realm, realms_table)
    jobs = load_table(f'lava.{realm}.jobs', 'job_id', aws_session)
    bad_jobs = []

    ts_start = now_tz()
    ts_ustart = ts_start.astimezone(UTC)

    for job_id, cached_spec in jobs.items():
        if suppression and nocheck(check_type, cached_spec):
            LOG.debug('Check %s suppressed for %s', check_type, job_id)
            continue
        owner = cached_spec.get('owner')

        # We need to augment the job spec like the lava worker would do. The
        # content of the extra fields doesn't matter but they need to be present
        # and of the right type. The augmentation below mutates nested dicts, so
        # it must work on a deep copy, not on the cached (shared) table item.
        job_spec: dict[str, Any] = {
            'globals': {},
            'parameters': {},
            'state': {},
            'vars': {},
            **copy.deepcopy(cached_spec),
            'run_id': '',
            'ts_start': ts_start,
            'ts_ustart': ts_ustart,
            'ts_dispatch': ts_start,
        }

        # A spec can set these fields to null or to a non-object value. The
        # augmentation below would then crash and take the whole check with it.
        non_objects = [
            field
            for field in ('parameters', 'globals', 'state')
            if not isinstance(job_spec[field], dict)
        ]
        if non_objects:
            bad_jobs.append(
                BadItem(
                    id=job_id,
                    owner=owner,
                    details=f'{non_objects[0].capitalize()} must be an object',
                )
            )
            continue

        # State items can be structured -- fake them with sponges
        state = job_spec['state']
        for k, v in state.items():
            if v == {}:
                state[k] = Sponge(lambda: 'dummy-state-value')

        job_spec['parameters'].setdefault('vars', {})
        lava_globals = job_spec['globals'].setdefault('lava', {})
        lava_globals.setdefault('master_job_id', job_id)
        lava_globals.setdefault('parent_job_id', job_id)
        lava_globals.setdefault('master_start', ts_start)
        lava_globals.setdefault('master_ustart', ts_ustart)
        lava_globals.setdefault('parent_start', ts_start)
        lava_globals.setdefault('parent_ustart', ts_ustart)

        # Wrap globals so references to undeclared globals are recorded.
        job_spec_globals = TrackedMapping(job_spec['globals'], default_factory=str)
        job_spec['globals'] = job_spec_globals
        render_vars = jinja_render_vars(
            job_spec,
            realm_info,
            vars=job_spec['parameters']['vars'],
            result=Sponge(),
        )

        # This is a bit hacky but it does actually work for vast majority of cases.
        template_txt = format_dict_unescaped(job_spec)
        try:
            _ = Template(template_txt).render(**render_vars)
        except TemplateRuntimeError as e:
            bad_jobs.append(
                BadItem(id=job_id, owner=owner, details=f'Template runtime error: {e}')
            )
        except TemplateError as e:
            bad_jobs.append(BadItem(id=job_id, owner=owner, details=f'Template error: {e}'))
        except Exception as e:
            bad_jobs.append(BadItem(id=job_id, owner=owner, details=str(e)))
        finally:
            # Report undeclared globals even if rendering failed part way.
            if job_spec_globals.unknown_refs:
                bad_jobs.append(
                    BadItem(
                        id=job_id,
                        owner=owner,
                        details=f'Undeclared globals: {", ".join(job_spec_globals.unknown_refs)}',
                    )
                )

    return bad_jobs


# ..............................................................................
# endregion job checkers
# ------------------------------------------------------------------------------


# ..............................................................................
# region connection checkers
# ------------------------------------------------------------------------------


# ------------------------------------------------------------------------------
# noinspection PyUnusedLocal
@lava_check
def conmeta(
    realm: str, aws_session: boto3.Session, suppression=True, check_type=None
) -> list[BadItem]:
    """
    Find connection specs with missing metadata (e.g. description, owner).

    A metadata field counts as missing when it is absent or empty.

    :param realm:           Lava realm.
    :param aws_session:     A boto3 Session.
    :param suppression:     Not used. Check cannot be suppressed. Captain's call.
    :param check_type:      This is injected by the decorator.

    :return:                A list of BadItem entries, one per offending connection.
    """

    LOG.debug('Running %s', check_type)

    connections = load_table(f'lava.{realm}.connections', 'conn_id', aws_session)

    # (conn_id, spec, [missing field names]) for every connection.
    audit = (
        (conn_id, conn, [name for name in ('description', 'owner') if not conn.get(name)])
        for conn_id, conn in connections.items()
    )
    return [
        BadItem(id=conn_id, owner=conn.get('owner'), details='Missing: ' + ', '.join(absent))
        for conn_id, conn, absent in audit
        if absent
    ]


# ..............................................................................
# endregion connection checkers
# ------------------------------------------------------------------------------


# ..............................................................................
# region s3trigger checkers
# ------------------------------------------------------------------------------


# ------------------------------------------------------------------------------
# noinspection PyUnusedLocal
@lava_check
def trigmeta(
    realm: str, aws_session: boto3.Session, suppression=True, check_type=None
) -> list[BadItem]:
    """
    S3trigger specs with missing metadata (e.g. description, owner).

    A metadata field counts as missing when it is absent or empty.

    :param realm:           Lava realm.
    :param aws_session:     A boto3 Session.
    :param suppression:     Not used. Check cannot be suppressed. Captain's call.
    :param check_type:      This is injected by the decorator.

    :return:                A list of BadItem entries, one per offending trigger.
    """

    LOG.debug('Running %s', check_type)

    specs = load_table(f'lava.{realm}.s3triggers', 'trigger_id', aws_session)
    bad_specs = []

    for trig_id, spec in specs.items():
        missing = [field for field in ('description', 'owner') if not spec.get(field)]
        if missing:
            bad_specs.append(
                BadItem(
                    id=trig_id,
                    # The field is 'owner'; the old misspelling always gave None.
                    owner=spec.get('owner'),
                    details=f'Missing: {", ".join(missing)}',
                )
            )

    return bad_specs


# ------------------------------------------------------------------------------
# noinspection PyUnusedLocal
@lava_check
def trigslash(
    realm: str, aws_session: boto3.Session, suppression=True, check_type=None
) -> list[BadItem]:
    """
    Find s3trigger specs with prefix ending in slash.

    :param realm:           Lava realm.
    :param aws_session:     A boto3 Session.
    :param suppression:     Not used. Check cannot be suppressed as these
                            trigger specs just won't work.
    :param check_type:      This is injected by the decorator.

    :return:                A list of BadItem entries, one per offending trigger.
    """

    LOG.debug('Running %s', check_type)

    specs = load_table(f'lava.{realm}.s3triggers', 'trigger_id', aws_session)
    bad_specs = []

    for trig_id, spec in specs.items():
        prefix = spec.get('prefix')
        # prefix may be absent or null; only a string can end with a slash.
        if isinstance(prefix, str) and prefix.endswith('/'):
            bad_specs.append(
                BadItem(
                    id=trig_id,
                    owner=spec.get('owner'),
                    details='Prefix ending with / won\'t work',
                )
            )

    return bad_specs


# ..............................................................................
# endregion s3trigger checkers
# ------------------------------------------------------------------------------


# ------------------------------------------------------------------------------
def process_cli_args() -> argparse.Namespace:
    """
    Process the command line arguments.

    If --realm is not given, the realm defaults to the value of the LAVA_REALM
    environment variable (None if that is not set either).

    :return:    The args namespace.
    """

    argp = argparse.ArgumentParser(prog=PROG, description='Check lava specifications for problems.')

    argp.add_argument(
        '-c',
        '--check',
        action='append',
        metavar='GLOB',
        help=(
            'Run the health checks with names matching the given glob patterns.'
            ' Can be used multiple times.'
            ' If not specified, print a list of available checks.'
        ),
    )

    argp.add_argument('--profile', action='store', help='As for AWS CLI.')

    argp.add_argument(
        '-r',
        '--realm',
        action='store',
        # Implements the LAVA_REALM fallback promised by the help text.
        default=os.environ.get('LAVA_REALM'),
        help=(
            'Lava realm name. If not specified, the environment variable LAVA_REALM must be set.'
        ),
    )

    argp.add_argument(
        '-S',
        '--no-suppress',
        dest='suppression',
        action='store_false',
        help=(
            'Disable suppression of checks for specific DynamoDB entries via the'
            f' {NO_CHECK} field. By default suppression of specific checks is'
            ' permitted for some check types.'
        ),
    )

    argp.add_argument('-v', '--version', action='version', version=__version__)

    # ------------------------------
    # Logging options

    logp = argp.add_argument_group('logging arguments')
    logp.add_argument(
        '--no-colour',
        '--no-color',
        dest='colour',
        action='store_false',
        default=True,
        help='Don\'t use colour in information messages.',
    )

    logp.add_argument(
        '-l',
        '--level',
        metavar='LEVEL',
        default=LOGLEVEL,
        help=(
            'Print messages of a given severity level or above. The standard logging level names'
            ' are available but debug, info, warning and error are most useful.'
            f' The Default is {LOGLEVEL}.'
        ),
    )

    logp.add_argument(
        '--log',
        action='store',
        help='Log to the specified target. This can be either a file'
        ' name or a syslog facility with an @ prefix (e.g. @local0).',
    )

    logp.add_argument(
        '--tag',
        action='store',
        default=PROG,
        help=f'Tag log entries with the specified value. The default is {PROG}.',
    )

    return argp.parse_args()


# ------------------------------------------------------------------------------
def main() -> int:
    """
    Parse the command line and run the selected health checks.

    With no --check option, the list of available checks is printed instead.
    For each selected check that finds problems, a table of the offending items
    is printed to stdout.

    :return:    Exit status (always 0; problems are raised as exceptions).

    :raise Exception: If a --check pattern matches no check type, or no realm
                      is specified.
    """

    # Basic logging until the command line options are known.
    setup_logging(LOGLEVEL, name=LOGNAME, prefix=PROG)
    args = process_cli_args()
    setup_logging(args.level, name=LOGNAME, target=args.log, colour=args.colour, prefix=args.tag)

    if not args.check:
        print_checkers()
        return 0

    # Work out which checks we have to run. Take care to avoid duplicates.
    checks = set()
    bad_checks = set()
    for check_pattern in args.check:
        matched = {check for check in CHECKERS if fnmatch(check, check_pattern)}
        if matched:
            checks |= matched
        else:
            bad_checks.add(check_pattern)

    if bad_checks:
        raise Exception(f'No matching check types: {", ".join(sorted(bad_checks))}')

    if not args.realm:
        raise Exception('Lava realm must be specified')

    # Honour --profile (as for the AWS CLI); None selects boto3's default
    # credential resolution, the same as a bare boto3.Session().
    aws_session = boto3.Session(profile_name=args.profile)

    for check_type in sorted(checks):
        checker = CHECKERS[check_type]
        bad_items = checker.func(args.realm, aws_session, suppression=args.suppression)

        if bad_items:
            print(
                Template(OUTPUT_TEMPLATE).render(
                    items=sorted(bad_items, key=lambda x: x.id),
                    check_type=check_type,
                    description=checker.description,
                )
            )

    return 0


# ------------------------------------------------------------------------------
if __name__ == '__main__':
    # Uncomment for debugging (lets exceptions escape with a full traceback).
    # sys.exit(main())  # noqa: ERA001
    # Use sys.exit rather than the site-module exit(), which is intended for the
    # interactive shell and is absent when Python runs without site (python -S).
    try:
        sys.exit(main())
    except Exception as ex:
        print(f'{PROG}: {ex}', file=sys.stderr)
        sys.exit(1)
    except KeyboardInterrupt:
        print('Interrupt', file=sys.stderr)
        sys.exit(2)
