#!/usr/bin/env python3

"""Command line utility to set and validate checksums on lava DynamoDB table entries."""

from __future__ import annotations

import argparse
import hashlib
import json
import os
import sys
from collections import namedtuple
from datetime import datetime, timezone
from enum import IntEnum
from typing import Any
from zipfile import ZIP_DEFLATED, ZipFile

import boto3
import jinja2

from lava.lib.aws import dynamo_scan_table
from lava.lib.fileops import sanitise_filename
from lava.lib.misc import DictChecksum, match_none
from lava.version import __version__

__author__ = 'Murray Andrews'

# Name the script was invoked as (no directory, no extension); prefixes error messages.
PROG = os.path.splitext(os.path.basename(sys.argv[0]))[0]

# Item attribute that stores the checksum string.
CHECKSUM_KEY_NAME = 'x-lava-chk'

# Passed as ``ignore`` when computing / validating checksums. Note that it
# matches CHECKSUM_KEY_NAME itself, so the checksum never covers itself.
IGNORE_KEYS_GLOB = '[xX]-*'

HASH_ALGORITHM = 'sha256'  # Default for --hash-algorithm
TAG_DEFAULT = 'LC'

# Jinja2 templates for reporting results, keyed by the -f / --format choice.
# output_results() renders the selected template with "results" (the filtered,
# key-sorted list of CheckResult) plus any extra keyword arguments; the cmd_*
# functions supply "key_field" (the table's key attribute name), which the
# html and md formats use as the first column heading.
OUTPUT_FORMATS = {
    # Plain text, one "<key>: <SEVERITY> : <description>" line per result.
    # Default format when stdout is not a tty.
    'txt': """
{%- for r in results -%}
{{ r.id }}: {{ r.severity.name }} : {{ r.description }}
{% endfor %}""",
    # Like txt, but the severity is conveyed by ANSI colour rather than by name.
    # Default format when stdout is a tty.
    'tty': """
{%- set colour={
    'INFO': '\033[32m', 'WARNING': '\033[33m', 'NOTICE': '\033[34m', 'ERROR': '\033[31m'
} -%}
{%- set RESET = '\033[0m' %}
{%- for r in results -%}
{{ colour[r.severity.name] }}{{ r.id }}: {{ r.description }}{{ RESET }}
{% endfor %}""",
    # Deliberately a bare <TABLE> only (no surrounding page markup), presumably
    # so it can be embedded in a larger document. Values are HTML-escaped via
    # the autoescape block.
    'html': """
{%- autoescape true -%}
    <TABLE class="lava-checksum-report">
        <THEAD>
            <TR>
                <TH>{{ key_field }}</TH>
                <TH>Severity</TH>
                <TH>Description</TH>
            </TR>
        </THEAD>
        <TBODY>
        {%- for r in results %}
            <TR class="result-{{ r.severity.name | lower }}">
                <TD>{{ r.id }}</TD>
                <TD>{{ r.severity.name }}</TD>
                <TD>{{ r.description }}</TD>
            </TR>
        {% endfor %}
        </TBODY>
    </TABLE>
{%- endautoescape %}
""",
    # Markdown table.
    # NOTE(review): values are not escaped, so a "|" in a key or description
    # would break the table row.
    'md': """|{{ key_field }}|Severity|Description|
|--|--|--|
{% for r in results -%}
|{{ r.id }}|{{ r.severity.name }}|{{ r.description }}|
{% endfor %}""",
}

# Describes one lava DynamoDB table: the title used with -t/--table, the
# partition key attribute, and the table name template ({realm} is replaced
# by the realm name; the realms table has no placeholder).
TableDescriptor = namedtuple('TableDescriptor', 'title key name')

# Tables selectable on the command line. The first entry whose title starts
# with the (lower-cased) -t value wins. "triggers" is an alias for s3triggers.
TABLES = (
    TableDescriptor(title='jobs', key='job_id', name='lava.{realm}.jobs'),
    TableDescriptor(title='connections', key='conn_id', name='lava.{realm}.connections'),
    TableDescriptor(title='s3triggers', key='trigger_id', name='lava.{realm}.s3triggers'),
    TableDescriptor(title='triggers', key='trigger_id', name='lava.{realm}.s3triggers'),
    TableDescriptor(title='realms', key='realm', name='lava.realms'),
)

# One reported outcome: item key value, human readable text, ResultSeverity.
CheckResult = namedtuple('CheckResult', 'id description severity')


# ------------------------------------------------------------------------------
class ResultSeverity(IntEnum):
    """
    Ordered severity levels attached to each checksum result.

    Levels compare as integers; output_results() drops results below the
    threshold derived from the number of -v flags.
    """

    INFO = 0  # Routine outcome (e.g. checksum OK); shown with -vv
    WARNING = 1  # Worth attention (e.g. missing checksum in "check"); shown with -v
    NOTICE = 2  # A change made, or one a dry run would make; always shown
    ERROR = 3  # Checksum invalid or could not be processed; always shown


# ------------------------------------------------------------------------------
def process_cli_args() -> argparse.Namespace:
    """
    Process the command line arguments.

    After parsing, the namespace is normalised as follows:

    -   ``table`` is replaced by the matching ``TableDescriptor`` from TABLES.
    -   ``realm`` falls back to the LAVA_REALM environment variable. A realm
        is mandatory for every table except ``realms``, whose table name does
        not depend on the realm.
    -   ``severity`` is converted from the number of -v flags into the lowest
        ``ResultSeverity`` to report (NOTICE and above are always reported).
    -   ``format`` defaults to "tty" if stdout is a terminal, else "txt".

    :return:    The args namespace.
    """

    argp = argparse.ArgumentParser(
        prog=PROG,
        description='Set and validate checksums on lava DynamoDB table entries.',
    )

    argp.add_argument(
        '-f',
        '--format',
        action='store',
        choices=OUTPUT_FORMATS,
        help='Output format. Default is "tty" if stdout is a terminal and "txt" otherwise.',
    )

    argp.add_argument(
        '--hash-algorithm',
        action='store',
        dest='algorithm',
        # algorithms_guaranteed is a set -- sort it so --help output is stable.
        choices=sorted(hashlib.algorithms_guaranteed),
        default=HASH_ALGORITHM,
        help=f'Algorithm to use for checksums. Default is {HASH_ALGORITHM}.',
    )

    argp.add_argument(
        '-i',
        '--ignore-case',
        action='store_true',
        help='Matching of glob patterns is case insensitive.',
    )

    argp.add_argument('--profile', action='store', help='As for AWS CLI.')

    argp.add_argument(
        '-r',
        '--realm',
        action='store',
        help=(
            'Lava realm name. If not specified, the environment variable LAVA_REALM'
            ' is used. A realm is required for all tables except "realms".'
        ),
    )

    argp.add_argument(
        '-t',
        '--table',
        action='store',
        default='jobs',
        help=(
            'Extract from the specified table. This can be one of'
            ' jobs, connections, s3triggers (or triggers) or realms.'
            ' Any unique initial sequence is accepted. The default is "jobs".'
        ),
    )

    argp.add_argument(
        '-v',
        '--verbose',
        action='count',
        dest='severity',
        default=0,
        help=(
            'Increase verbosity. By default, only checksum errors, updates etc'
            ' are reported. Can be specified multiple times.'
        ),
    )
    argp.add_argument('--version', action='version', version=__version__)

    # ----------------------------------------
    # Sub parsers

    subp = argp.add_subparsers(required=True)

    # . . . . . . . . . . . . . . . . . . . .
    # "check" command

    c_check = subp.add_parser('check', help='Validate checksums.')
    c_check.set_defaults(func=cmd_check)

    c_check.add_argument(
        'glob_pat',
        metavar='glob-pattern',
        nargs='*',
        action='store',
        help=(
            'Only check items with keys that match any of the specified glob'
            ' style patterns. Default behaviour is to check all items.'
        ),
    )

    # . . . . . . . . . . . . . . . . . . . .
    # "add" command

    c_add = subp.add_parser('add', help='Add missing checksums.')
    c_add.set_defaults(func=cmd_add)

    c_add.add_argument(
        '-d',
        '--dry-run',
        action='store_true',
        help='Do a dry run. No changes are made.',
    )

    c_add.add_argument(
        'glob_pat',
        metavar='glob-pattern',
        nargs='*',
        action='store',
        help=(
            'Only add checksums to items with keys that match any of the specified'
            ' glob style patterns. Default behaviour is to update all items.'
        ),
    )

    # . . . . . . . . . . . . . . . . . . . .
    # "update" command

    c_update = subp.add_parser('update', help='Update existing checksums.')
    c_update.set_defaults(func=cmd_update)

    c_update.add_argument(
        '-d',
        '--dry-run',
        action='store_true',
        help='Do a dry run. No changes are made.',
    )

    c_update.add_argument(
        '-f',
        '--force',
        action='store_true',
        help=(
            'Update valid checksums that are different from current settings'
            ' (e.g. different algorithm). By default, valid checksums are not'
            ' updated.'
        ),
    )

    c_update.add_argument(
        'glob_pat',
        metavar='glob-pattern',
        nargs='*',
        action='store',
        help=(
            'Only update items with keys that match any of the specified glob'
            ' style patterns. Default behaviour is to update all items.'
        ),
    )

    # . . . . . . . . . . . . . . . . . . . .
    args = argp.parse_args()

    # Work out which DynamoDB table is involved. Any initial sequence of a
    # table title is accepted; the first match in TABLES wins.
    table_prefix = args.table.lower()
    for table in TABLES:  # type: TableDescriptor
        # An empty prefix would match every table, so it is rejected.
        # noinspection PyUnresolvedReferences
        if table_prefix and table.title.startswith(table_prefix):
            args.table = table
            break
    else:
        argp.error(f'Unknown table: {args.table}')

    # Resolve the realm only after the table is known: the realms table name
    # has no {realm} placeholder, so it can be used without one.
    if not args.realm:
        args.realm = os.environ.get('LAVA_REALM')
    if args.table.title != 'realms' and not args.realm:
        argp.error(
            'Lava realm must be specified via -r, --realm or LAVA_REALM environment variable.'
        )

    # Convert the -v count into a reporting threshold.
    # NOTICE and higher severity are always reported.
    args.severity = ResultSeverity.NOTICE - args.severity

    if not args.format:
        args.format = 'tty' if os.isatty(1) else 'txt'

    return args


# ------------------------------------------------------------------------------
def dynamo_update_item_field(
    table, key_id: str, key_value: str, field_name: str, field_value: Any
) -> None:
    """
    Add or overwrite a single attribute on one DynamoDB item.

    The item is addressed by its partition key alone, so the table must not
    have a sort key.

    :param table:       DynamoDB table resource.
    :param key_id:      Partition key attribute name.
    :param key_value:   Partition key value of the item to change.
    :param field_name:  Name of the attribute to add or overwrite.
    :param field_value: New value for the attribute.
    """

    # The attribute name goes through an ExpressionAttributeNames placeholder
    # because names such as 'x-lava-chk' contain characters ('-') that cannot
    # appear directly in an update expression.
    request = {
        'Key': {key_id: key_value},
        'UpdateExpression': 'SET #field_name = :field_value',
        'ExpressionAttributeNames': {'#field_name': field_name},
        'ExpressionAttributeValues': {':field_value': field_value},
    }
    table.update_item(**request)


# ---------------------------------------------------------------------------------------
def cmd_add(args: argparse.Namespace) -> None:
    """
    Add checksums to items that don't already have one.

    Items that already have a (non-empty) checksum are left alone. Before an
    item is modified it is saved as JSON into a zip backup file named
    ``<table-name>-<UTC-timestamp>.zip`` in the current directory. The backup
    file is removed again if nothing was saved to it (e.g. on a dry run).

    :param args:    Args namespace from process_cli_args().
    """

    aws_session = boto3.Session(profile_name=args.profile)
    table_name = args.table.name.format(realm=args.realm)
    table = aws_session.resource('dynamodb').Table(table_name)
    ts = datetime.now(timezone.utc).strftime('%Y-%m-%dT%H:%M:%SZ')
    backup_file = f'{table_name}-{ts}.zip'
    results = []

    # compresslevel has no effect with the ZipFile default (ZIP_STORED), so
    # ask for deflate compression explicitly.
    with ZipFile(backup_file, 'w', compression=ZIP_DEFLATED, compresslevel=9) as zf:
        for item in dynamo_scan_table(table_name, aws_session):
            key = item[args.table.key]
            if args.glob_pat and match_none(key, args.glob_pat, ignore_case=args.ignore_case):
                continue

            if item.get(CHECKSUM_KEY_NAME):
                results.append(CheckResult(key, 'Has existing checksum', ResultSeverity.INFO))
                continue

            new_checksum = DictChecksum.for_dict(
                item, ignore=IGNORE_KEYS_GLOB, algorithm=args.algorithm
            )

            if args.dry_run:
                action = 'would be added'
            else:
                # Back up the item before modifying it.
                # NOTE(review): json.dumps() raises on non-JSON types (e.g. the
                # Decimal numbers returned by the boto3 resource API) unless
                # dynamo_scan_table() converts them -- confirm.
                zf.writestr(
                    f'{sanitise_filename(key)}.json', json.dumps(item, indent=4, sort_keys=True)
                )
                action = 'added'
                dynamo_update_item_field(
                    table, args.table.key, key, CHECKSUM_KEY_NAME, str(new_checksum)
                )

            results.append(CheckResult(key, f'Checksum {action}', ResultSeverity.NOTICE))

        backed_up = bool(zf.namelist())

    if not backed_up:
        # Nothing to backup -- delete the backup file. This is done only after
        # the ZipFile has been closed, because removing a file that is still
        # open fails on some platforms (e.g. Windows).
        os.unlink(backup_file)

    output_results(results, args.severity, args.format, key_field=args.table.key)


# ---------------------------------------------------------------------------------------
def cmd_update(args: argparse.Namespace) -> None:
    """
    Update existing object checksums in DynamoDB.

    Items with no checksum are skipped. An existing checksum is replaced when
    it cannot be parsed or verified, or no longer matches the item. A checksum
    that is still valid but differs from the one that would be generated now
    (e.g. a different algorithm) is only replaced when --force is given.

    Before an item is modified it is saved as JSON into a zip backup file
    named ``<table-name>-<UTC-timestamp>.zip`` in the current directory. The
    backup file is removed again if nothing was saved to it.

    :param args:    Args namespace from process_cli_args().
    """

    aws_session = boto3.Session(profile_name=args.profile)
    table_name = args.table.name.format(realm=args.realm)
    table = aws_session.resource('dynamodb').Table(table_name)
    ts = datetime.now(timezone.utc).strftime('%Y-%m-%dT%H:%M:%SZ')
    backup_file = f'{table_name}-{ts}.zip'
    results = []

    # compresslevel has no effect with the ZipFile default (ZIP_STORED), so
    # ask for deflate compression explicitly.
    with ZipFile(backup_file, 'w', compression=ZIP_DEFLATED, compresslevel=9) as zf:
        for item in dynamo_scan_table(table_name, aws_session):
            key = item[args.table.key]
            if args.glob_pat and match_none(key, args.glob_pat, ignore_case=args.ignore_case):
                continue

            existing_checksum = None
            # noinspection PyBroadException
            try:
                existing_checksum = DictChecksum.from_str(item[CHECKSUM_KEY_NAME])
            except KeyError:
                results.append(CheckResult(key, 'No checksum', ResultSeverity.INFO))
                continue
            except Exception:
                # Assume the existing checksum is damaged and we need to fix it
                pass

            new_checksum = DictChecksum.for_dict(
                item, ignore=IGNORE_KEYS_GLOB, algorithm=args.algorithm
            )

            if existing_checksum == new_checksum:
                results.append(CheckResult(key, 'Checksum OK', ResultSeverity.INFO))
                continue

            if not args.force and existing_checksum:
                # noinspection PyBroadException
                try:
                    still_valid = existing_checksum.is_valid_for(item, ignore=IGNORE_KEYS_GLOB)
                except Exception:
                    # Parsed but cannot be verified -- treat it as damaged and
                    # replace it rather than aborting the whole run.
                    still_valid = False
                if still_valid:
                    results.append(
                        CheckResult(key, 'Checksum different but OK', ResultSeverity.INFO)
                    )
                    continue

            if args.dry_run:
                action = 'would be updated'
            else:
                # Back up the item before modifying it.
                # NOTE(review): json.dumps() raises on non-JSON types (e.g. the
                # Decimal numbers returned by the boto3 resource API) unless
                # dynamo_scan_table() converts them -- confirm.
                zf.writestr(
                    f'{sanitise_filename(key)}.json', json.dumps(item, indent=4, sort_keys=True)
                )
                action = 'updated'
                dynamo_update_item_field(
                    table, args.table.key, key, CHECKSUM_KEY_NAME, str(new_checksum)
                )
            results.append(CheckResult(key, f'Checksum {action}', ResultSeverity.NOTICE))

        backed_up = bool(zf.namelist())

    if not backed_up:
        # Nothing to backup -- delete the backup file. This is done only after
        # the ZipFile has been closed, because removing a file that is still
        # open fails on some platforms (e.g. Windows).
        os.unlink(backup_file)

    output_results(results, args.severity, args.format, key_field=args.table.key)


# ---------------------------------------------------------------------------------------
def cmd_check(args: argparse.Namespace) -> None:
    """
    Verify the stored checksum of each selected item in DynamoDB.

    Nothing is written to the table. Each item produces one result: a missing
    checksum is a WARNING; an unparseable checksum, a mismatch or a failure
    while verifying is an ERROR; a matching checksum is INFO.

    :param args:    Args namespace from process_cli_args().
    """

    session = boto3.Session(profile_name=args.profile)
    key_field = args.table.key
    findings = []

    for entry in dynamo_scan_table(args.table.name.format(realm=args.realm), session):
        entry_id = entry[key_field]
        if args.glob_pat and match_none(entry_id, args.glob_pat, ignore_case=args.ignore_case):
            continue

        try:
            stored = DictChecksum.from_str(entry[CHECKSUM_KEY_NAME])
        except KeyError:
            findings.append(CheckResult(entry_id, 'No checksum', ResultSeverity.WARNING))
            continue
        except ValueError as err:
            findings.append(CheckResult(entry_id, str(err), ResultSeverity.ERROR))
            continue

        # Any failure during verification is reported against the item
        # rather than aborting the scan.
        try:
            matches = stored.is_valid_for(entry, ignore=IGNORE_KEYS_GLOB)
        except Exception as err:
            findings.append(CheckResult(entry_id, str(err), ResultSeverity.ERROR))
            continue

        if matches:
            findings.append(CheckResult(entry_id, 'OK', ResultSeverity.INFO))
        else:
            findings.append(CheckResult(entry_id, 'Invalid checksum', ResultSeverity.ERROR))

    output_results(findings, args.severity, args.format, key_field=key_field)


# ------------------------------------------------------------------------------
def output_results(results: list[CheckResult], severity: int, fmt: str, **kwargs) -> None:
    """
    Print results at or above a severity threshold to stdout.

    Results are sorted by item key before rendering. If no results survive
    the severity filter, nothing at all is printed.

    :param results:     Checksum action / check results.
    :param severity:    Minimum severity to include.
    :param fmt:         Key into OUTPUT_FORMATS selecting the template.
    :param kwargs:      Extra variables passed to the Jinja template.
    """

    shown = [res for res in results if res.severity >= severity]
    if not shown:
        return
    shown.sort(key=lambda res: res.id)
    template = jinja2.Template(OUTPUT_FORMATS[fmt])
    print(template.render(results=shown, **kwargs))


# ---------------------------------------------------------------------------------------
def main() -> int:
    """
    Parse the command line and run the selected sub-command.

    :return:    Process exit status. Always 0; failures propagate as exceptions.
    """

    cli = process_cli_args()
    handler = cli.func
    handler(cli)
    return 0


# ------------------------------------------------------------------------------
if __name__ == '__main__':
    # Uncomment for debugging (lets exceptions escape with a full traceback)
    # sys.exit(main())  # noqa: ERA001

    # Use sys.exit() rather than the site-injected exit(): the latter is meant
    # for the interactive interpreter and is absent under "python -S" or in
    # frozen executables. SystemExit is not an Exception subclass, so a normal
    # exit (or an argparse error) passes straight through the handlers below.
    try:
        sys.exit(main())
    except Exception as ex:
        print(f'{PROG}: {ex}', file=sys.stderr)
        sys.exit(1)
    except KeyboardInterrupt:
        print('Interrupt', file=sys.stderr)
        sys.exit(2)
