#!/usr/bin/env python3

"""Set the AMI in a lava CloudFormation stack."""

from __future__ import annotations

import argparse
import sys
from collections.abc import Callable, Collection, Iterable
from dataclasses import dataclass
from fnmatch import fnmatchcase
from pathlib import Path
from types import MappingProxyType
from typing import Any, Union

import boto3
from colorama import Fore, Style
from tabulate import tabulate

from lava.version import __version__

__author__ = 'Murray Andrews'

PROG = Path(sys.argv[0]).stem

# Only include this many recent images in the selection process.
IMAGE_LIMIT = 5

# A lava worker CloudFormation stack must have at least these parameters or it's
# not a worker. Unfortunately prior to lava v7.1.0 the AMI parameter was
# sakAmiId, which has since been renamed to amiId, so both variants are
# recognised. frozensets keep these shared constants immutable.
WORKER_STACK_PARAMS = frozenset({'realm', 'worker', 'amiId'})
WORKER_STACK_PARAMS_OLD = frozenset({'realm', 'worker', 'sakAmiId'})

# tabulate table style used for all output tables.
TABLE_FORMAT = 'fancy_grid'

# Maximum time to wait for each stack update to complete, and how often to poll.
STACK_UPDATE_WAIT_SECS = 10 * 60
STACK_UPDATE_POLL_SECS = 15

# Some type aliases
Boto3Image = Any
Boto3Stack = Any
Strings = Union[str, Collection[str]]

# Special codes for user interaction
RESPONSE_QUIT = -2
RESPONSE_SKIP = -1


# ------------------------------------------------------------------------------
class Fmt:
    """
    Wrapper for formatting text with colorama.

    Named styles from ``_FMT`` are available as attributes that return a
    formatting function, e.g. ``Fmt().title(s)``. The ``wheel()`` method
    styles successive strings with successive entries of ``_WHEEL``.
    """

    _FMT = MappingProxyType(
        {
            'title': Fore.BLUE + Style.BRIGHT,
            'emphasis': Fore.GREEN,
            'bold': Style.BRIGHT,
            'warning': Fore.MAGENTA,
            'error': Fore.RED,
            'ok': Fore.GREEN,
            'faint': Style.DIM,
        }
    )

    # Colour Wheel
    _WHEEL = (
        Fore.GREEN,
        Fore.BLUE,
        Fore.MAGENTA,
        Fore.CYAN,
        Fore.YELLOW,
        Fore.RED,
        Style.DIM,
    )

    def __init__(self):
        """Init."""

        # Index into _WHEEL of the style that wheel() will use next.
        self._wheel = 0

    def __getattr__(self, style: str) -> Callable[[str], str]:
        """
        Return a formatting function for a named style e.g. Fmt().title(s).

        Python only calls this for names not found by normal attribute lookup.

        :param style:   The style name. Must be a key of ``_FMT``.
        :return:        A function that wraps a string in the style's escape
                        codes followed by a reset.
        :raise AttributeError: If ``style`` is not a known style. Raising this
                        here (rather than returning a function that fails later
                        with KeyError) keeps ``hasattr()``,
                        ``getattr(obj, name, default)`` and protocols that probe
                        for optional dunder methods (e.g. ``copy``) working.
        """
        try:
            code = self._FMT[style]
        except KeyError:
            raise AttributeError(
                f'{type(self).__name__!r} object has no style {style!r}'
            ) from None
        return lambda s: f'{code}{s}{Style.RESET_ALL}'

    def wheel(self, s: str, new: bool = False) -> str:
        """
        Style strings successively around a style wheel.

        When it runs off the end of the wheel, the strings are unchanged.

        :param s:       String to style.
        :param new:     Reset the wheel before styling ``s``.
        :return:        The styled text.
        """

        if new:
            self._wheel = 0

        if self._wheel >= len(self._WHEEL):
            return s

        colour = self._WHEEL[self._wheel]
        self._wheel += 1
        return f'{colour}{s}{Style.RESET_ALL}'


# Shared formatter used for all styled output in this module.
F = Fmt()


# ------------------------------------------------------------------------------
@dataclass(frozen=True)
class AttrExtractor:
    """
    Pair a display name with a function that pulls one value out of an object.

    Instances are immutable (frozen dataclass). A sequence of these describes
    the columns of a table row.
    """

    # Key / column heading under which the extracted value is reported.
    name: str
    # Called with the source object; its return value is the reported value.
    extractor: Callable


# EC2 image attributes shown in the selection table. The 'AMI Name' column
# reads x_cname, a styled copy of the AMI name that main() attaches to each
# image before the table is built.
IMAGE_ATTRS = (
    AttrExtractor('AMI ID', lambda i: i.image_id),
    AttrExtractor('AMI Name', lambda i: i.x_cname),
    AttrExtractor('Created', lambda i: i.creation_date[:10]),
    AttrExtractor('Python', lambda i: tag_value('PythonVersion', i.tags)),
    AttrExtractor('Arch', lambda i: i.architecture),
)

# CloudFormation stack attributes shown in the worker stack list.
# Helper functions are referenced via lambdas because they are defined further
# down the module and are only looked up when an extractor is called.
STACK_ATTRS = (
    AttrExtractor('Stack Name', lambda s: s.name),
    AttrExtractor('Realm', lambda s: stack_param_value('realm', s)),
    AttrExtractor('Worker', lambda s: stack_param_value('worker', s)),
    AttrExtractor(
        'Updated',
        # last_updated_time may be unset (presumably for never-updated stacks).
        lambda s: s.last_updated_time.strftime('%Y-%m-%d') if s.last_updated_time else None,
    ),
    # Single source of truth for the amiId / pre-v7.1.0 sakAmiId lookup.
    AttrExtractor('AMI ID', lambda s: lava_worker_ami(s)),
)


# ------------------------------------------------------------------------------
def obj_to_dict(obj: Any, attrs: Iterable[AttrExtractor]) -> dict[str, Any]:
    """
    Extract attributes of interest from an object into a dict.

    :param obj:     An arbitrary object.
    :param attrs:   An iterable of AttrExtractor specs. Each one contributes a
                    single entry, keyed on its ``name``, whose value is the
                    result of calling its ``extractor`` with ``obj``.
    :return:        A dictionary of interesting attributes, in ``attrs`` order.
                    If two specs share a name, the later one wins.
    """

    return {spec.name: spec.extractor(obj) for spec in attrs}


# ------------------------------------------------------------------------------
def tag_value(tag_name: str, tags: list[dict[str, str]] | None) -> str | None:
    """
    Extract value of a specified tag from an AWS tags structure.

    :param tag_name:    The tag name.
    :param tags:        A list of tag dicts, each with 'Key' and 'Value'
                        entries. None is treated the same as an empty list,
                        since boto3 resources (e.g. an ec2 Image) report
                        ``tags`` as None when the resource is untagged.
    :return:            The value of the first matching tag or None if not
                        present.
    """

    return next((t['Value'] for t in (tags or ()) if t['Key'] == tag_name), None)


# ------------------------------------------------------------------------------
def stack_param_value(parameter: str, stack: Boto3Stack) -> str | None:
    """
    Look up one CloudFormation parameter value on a boto3 Stack instance.

    :param parameter:   Name of the parameter (its ParameterKey).
    :param stack:       A boto3 Stack instance. A missing / None parameter
                        list is treated as empty.
    :return:            The ParameterValue of the first matching parameter,
                        or None if the stack has no such parameter.
    """

    # Each entry looks like {'ParameterKey': ..., 'ParameterValue': ...}
    matching_values = (
        entry['ParameterValue']
        for entry in (stack.parameters or ())
        if entry['ParameterKey'] == parameter
    )
    return next(matching_values, None)


# ------------------------------------------------------------------------------
def process_cli_args(argv: list[str] | None = None) -> argparse.Namespace:
    """
    Process the command line arguments.

    :param argv:    Arguments to parse. If None (the default), argparse uses
                    sys.argv[1:].
    :return:        The args namespace.
    """

    argp = argparse.ArgumentParser(
        prog=PROG,
        description='Manage the AMIs used in lava worker CloudFormation stacks.',
    )

    argp.add_argument(
        '-n',
        action='store',
        type=int,
        default=IMAGE_LIMIT,
        help=(
            'Only include specified number of most recent images of each type'
            f' in the selection list. Default {IMAGE_LIMIT}.'
        ),
    )

    argp.add_argument(
        '--sak',
        action='store_true',
        help=(
            'Include SAK AMIs in the compatible AMI list.'
            ' By default, only lava AMIs are included.'
        ),
    )

    argp.add_argument('--profile', action='store', help='As for AWS CLI.')

    argp.add_argument(
        '-U',
        '--update',
        action='store_true',
        help=(
            'Initiate an interactive update process to allow a new AMI to be'
            ' applied for selected stacks. If specified, one or more stack'
            ' patterns must be specified (no single *) to make it harder to'
            ' maniacally update a whole bunch of stacks in one go.'
            ' You can thank me later.'
        ),
    )

    argp.add_argument('-v', '--version', action='version', version=__version__)

    argp.add_argument(
        '-W',
        '--no-wait',
        action='store_false',
        dest='wait',
        help='Don\'t wait for CloudFormation stack updates to complete.',
    )

    argp.add_argument(
        'stack_patterns',
        metavar='STACK-NAME',
        action='store',
        nargs='*',
        help=(
            'CloudFormation stack name for a lava worker. Glob style patterns'
            ' can be used. If not specified, or any of the patterns is *, the'
            ' -U / --update option is not permitted.'
        ),
    )

    args = argp.parse_args(argv)

    # The limit is used to slice the newest-first image list. Zero would leave
    # nothing to select from, and a negative value would silently drop the
    # oldest images rather than limit the list.
    if args.n < 1:
        argp.error('-n must be a positive integer')

    # Guard against accidental bulk updates: update mode needs explicit
    # patterns and a bare * is refused.
    if args.update and ('*' in args.stack_patterns or not args.stack_patterns):
        argp.error('Update mode requires one or more stack patterns and * is not permitted')

    return args


# ------------------------------------------------------------------------------
def get_available_amis(
    pattern: Strings | None = None, aws_session: boto3.Session = None
) -> dict[str, Boto3Image]:
    """
    Get self-owned AMIs with names matching glob patterns.

    Note the real AMI name is used, not the value of the Name tag. Only AMIs
    in an `available` state are included. Name matching is case sensitive.

    :param pattern:     A glob pattern or collection of glob patterns to match
                        against AMI names. If not specified (or empty), all
                        available self-owned AMIs are returned.
    :param aws_session: A boto3 Session(). A default session is created if
                        not supplied.

    :return:            A dictionary of boto3 Image instances keyed on the AMI ID.
    """

    if isinstance(pattern, str):
        pattern = (pattern,)

    ec2 = (aws_session or boto3.Session()).resource('ec2')

    # Owner and state are filtered by the API call; the name match is done
    # locally with fnmatchcase.
    images = ec2.images.filter(
        Owners=['self'], Filters=[{'Name': 'state', 'Values': ['available']}]
    )

    return {
        img.id: img
        for img in images
        if not pattern or any(fnmatchcase(img.name, p) for p in pattern)
    }


# ------------------------------------------------------------------------------
def get_cfn_stacks(
    pattern: Strings | None = None, aws_session: boto3.Session = None
) -> dict[str, Boto3Stack]:
    """
    Get CloudFormation stacks whose names match one or more glob patterns.

    Name matching is case sensitive.

    :param pattern:     A glob pattern, or a collection of them, matched
                        against stack names. When empty or omitted, every
                        stack is returned.
    :param aws_session: A boto3 Session(). A default session is created if
                        not supplied.
    :return:            A dict of boto3 Stack instances keyed on stack name.
    """

    patterns = (pattern,) if isinstance(pattern, str) else pattern
    session = aws_session if aws_session else boto3.Session()
    cloudformation = session.resource('cloudformation')

    selected = {}
    for candidate in cloudformation.stacks.all():
        name = candidate.name
        if patterns and not any(fnmatchcase(name, glob) for glob in patterns):
            continue
        selected[name] = candidate
    return selected


# ------------------------------------------------------------------------------
def lava_worker_ami(stack: Boto3Stack) -> str | None:
    """
    Get the value of the AMI parameter from a lava worker stack.

    This is complicated by the change in name of this parameter in lava v7.1.0
    from ``sakAmiId`` to ``amiId``. The new name takes precedence; the old one
    is only consulted if the new one is missing or empty.

    :param stack:       A boto3 Stack instance.
    :return:            The parameter value, or a falsy value (None or an
                        empty string) if no AMI is set.
    """

    return stack_param_value('amiId', stack) or stack_param_value('sakAmiId', stack)


# ------------------------------------------------------------------------------
def get_lava_worker_stacks(
    pattern: Strings | None = None, aws_session: boto3.Session = None
) -> dict[str, Boto3Stack]:
    """
    Get the CloudFormation stacks that look like lava workers.

    A stack qualifies when its parameter names include every entry of either
    WORKER_STACK_PARAMS (current) or WORKER_STACK_PARAMS_OLD (pre-v7.1.0).

    :param pattern:     A glob pattern or collection of glob patterns to match
                        against stack names. If not specified, all stacks are
                        considered.
    :param aws_session: A boto3 Session().
    :return:            A dict of boto3 Stack instances keyed on stack name.
    """

    candidates = get_cfn_stacks(pattern=pattern, aws_session=aws_session)

    def _is_worker(stack: Boto3Stack) -> bool:
        # Stacks with no parameters at all cannot be workers.
        if not stack.parameters:
            return False
        # Each entry looks like {'ParameterKey': ..., 'ParameterValue': ...}
        keys = {entry['ParameterKey'] for entry in stack.parameters}
        return WORKER_STACK_PARAMS <= keys or WORKER_STACK_PARAMS_OLD <= keys

    workers = {}
    for name, stack in candidates.items():
        if _is_worker(stack):
            workers[name] = stack
    return workers


# ------------------------------------------------------------------------------
def update_lava_worker_ami(stack: Boto3Stack, ami_id: str) -> None:
    """
    Update the AMI on a lava worker.

    The stack is updated with its existing template. Whichever AMI parameter
    the stack has (``amiId``, or the pre-v7.1.0 ``sakAmiId``) is set to the new
    value; every other parameter keeps its previous value. The update is only
    initiated, not waited for.

    :param stack:       A boto3 Stack() instance.
    :param ami_id:      The new AMI ID.
    """

    new_parameters = []
    for p in stack.parameters:
        if p['ParameterKey'] in ('amiId', 'sakAmiId'):
            new_parameters.append({'ParameterKey': p['ParameterKey'], 'ParameterValue': ami_id})
        else:
            new_parameters.append({'ParameterKey': p['ParameterKey'], 'UsePreviousValue': True})

    update_args = {'UsePreviousTemplate': True, 'Parameters': new_parameters}
    # boto3 resource attributes are None when absent from the describe
    # response (e.g. a stack with no capabilities), and botocore's parameter
    # validation rejects None for a list parameter. Only pass Capabilities
    # when there is something to pass.
    if stack.capabilities:
        update_args['Capabilities'] = stack.capabilities

    stack.update(**update_args)
    print(F.faint(f'Update of {stack.name} initiated'))


# ------------------------------------------------------------------------------
def get_user_int(prompt: str, range_: tuple[int, int]) -> int:
    """
    Ask the user to pick an integer from an inclusive range.

    The prompt is extended with a hint showing the range and the skip / quit
    options. Non-integer or out-of-range input produces a warning and the
    question is asked again.

    :param prompt:  Prompt string.
    :param range_:  A (lower, upper) tuple of inclusive limits. Must be >= 0
                    (negative values would be ambiguous with the special
                    response codes).
    :return:        The integer entered, RESPONSE_SKIP (-1) if the user just
                    hit return, or RESPONSE_QUIT (-2) if they entered 'q'/'Q'.
    """

    low, high = range_[0], range_[1]
    hint = F.faint(f'({low}-{high}, return to skip, q to quit) ')
    full_prompt = prompt + ' ' + hint

    while True:
        answer = input(full_prompt).strip()
        if answer == '':
            return RESPONSE_SKIP
        if answer in ('q', 'Q'):
            return RESPONSE_QUIT

        try:
            value = int(answer)
            if value < low or value > high:
                raise ValueError('Out of range')
        except ValueError as err:
            print(F.warning(f'Bad response: {err}'))
        else:
            return value


# ------------------------------------------------------------------------------
def get_user_yes(prompt: str) -> bool:
    """
    Ask the user for an explicit yes.

    Only the word "yes" (case insensitive, surrounding whitespace ignored)
    counts as agreement. Anything else, including "y", is a no, and the
    question is not repeated.

    :param prompt:  Prompt string. A reminder that only "yes" means yes is
                    appended.
    :return:        True if the user answered yes, False otherwise.
    """

    hint = F.faint('(Only "yes" means yes) ')
    reply = input(prompt + ' ' + hint)
    normalised = reply.strip().lower()
    return normalised == 'yes'


# ------------------------------------------------------------------------------
def main() -> int:
    """
    Do the business.

    List the lava worker stacks matching the command line patterns (only those
    with an AMI set) and the most recent lava compatible AMIs. In update mode,
    prompt for a new AMI for each listed stack and initiate the requested stack
    updates. Unless --no-wait was given, then wait for them to complete.

    :return:    status
    """

    args = process_cli_args()
    aws_session = boto3.Session(profile_name=args.profile)

    # ----------------------------------------
    # Lookup the CloudFormation stacks that look to be lava workers.
    # We are only interested in real machines so we skip ones with no AMI set
    worker_stacks = get_lava_worker_stacks(pattern=args.stack_patterns, aws_session=aws_session)

    worker_table_data = sorted(
        (obj_to_dict(stk, STACK_ATTRS) for stk in worker_stacks.values() if lava_worker_ami(stk)),
        key=lambda s: s['Stack Name'],
    )

    # ----------------------------------------
    # Filter our images to a limited number of the most recent ones
    images = get_available_amis(
        ('lava-*', 'sak-*') if args.sak else ('lava-*',), aws_session=aws_session
    )
    # Newest first. Sorted once and reused for styling and for the table.
    newest_first = sorted(images.values(), reverse=True, key=lambda i: i.creation_date)

    # Add some style differentiation to AMI names. Every image is styled (not
    # just those shown in the table) because the worker table below also
    # looks names up from the full image set.
    for ami in newest_first:
        ami.x_cname = F.wheel(ami.name) if ami.name else ''

    image_table_data = [obj_to_dict(img, IMAGE_ATTRS) for img in newest_first[: args.n]]

    # ----------------------------------------
    # Augment the worker display table with its AMI name. AMIs outside the
    # compatible set (e.g. not self-owned, or not matching the name patterns)
    # have no name to show.
    for w in worker_table_data:
        ami = images.get(w['AMI ID'])
        w['AMI Name'] = ami.x_cname if ami is not None else None

    # ----------------------------------------
    # Print our worker stacks and AMI tables.
    print(F.title('Lava Worker Stacks'))
    print(tabulate(worker_table_data, headers='keys', tablefmt=TABLE_FORMAT))
    print(F.title('Lava Compatible AMIs'))
    print(tabulate(image_table_data, headers='keys', tablefmt=TABLE_FORMAT, showindex=True))

    if not args.update:
        return 0

    # ----------------------------------------
    # Update mode.

    if not image_table_data:
        # Otherwise the selection range below would be empty (0 to -1) and no
        # numeric response could ever be accepted.
        print(F.warning('No compatible AMIs available -- nothing to update'))
        return 0

    updates_to_wait_for = []
    for wtd in worker_table_data:
        print()
        ws = worker_stacks[wtd['Stack Name']]
        current_ami_id = wtd['AMI ID']

        response = get_user_int(f'New AMI for {F.bold(ws.name)}?', (0, len(image_table_data) - 1))
        if response == RESPONSE_QUIT:
            # Stop prompting, but still wait for any updates already initiated.
            break
        if response == RESPONSE_SKIP:
            continue

        new_ami_id = image_table_data[response]['AMI ID']
        if current_ami_id == new_ami_id:
            print('Current AMI and requested AMI are the same -- skipping')
            continue

        if not get_user_yes(
            f'Set AMI for {F.bold(ws.name)} to'
            f' {F.bold(new_ami_id)} / {F.bold(images[new_ami_id].name)}?'
        ):
            continue

        update_lava_worker_ami(stack=ws, ami_id=new_ami_id)
        updates_to_wait_for.append(ws)

    if not args.wait or not updates_to_wait_for:
        return 0

    # Wait for updates to complete, one stack at a time, polling every
    # STACK_UPDATE_POLL_SECS for at most STACK_UPDATE_WAIT_SECS per stack.
    cfn_client = aws_session.client('cloudformation')
    waiter = cfn_client.get_waiter('stack_update_complete')
    for stack in updates_to_wait_for:
        print(F.faint('Waiting ... '), end='', flush=True)
        waiter.wait(
            StackName=stack.name,
            WaiterConfig={
                'Delay': STACK_UPDATE_POLL_SECS,
                'MaxAttempts': STACK_UPDATE_WAIT_SECS // STACK_UPDATE_POLL_SECS,
            },
        )
        # Get the most recent event and see what happened.
        for e in stack.events.limit(1):
            fmt = F.ok if e.resource_status == 'UPDATE_COMPLETE' else F.error
            print(fmt(f'{e.logical_resource_id} - {e.resource_status}'))

    return 0


# ------------------------------------------------------------------------------
if __name__ == '__main__':
    # Uncomment for debugging (lets exceptions propagate with a traceback)
    # sys.exit(main())  # noqa: ERA001
    try:
        # sys.exit rather than the site-provided exit(), which is not
        # guaranteed to exist (e.g. under python -S or when frozen).
        sys.exit(main())
    except KeyboardInterrupt:
        print('\nInterrupt', file=sys.stderr)
        # An interrupted run must not report success to the caller.
        sys.exit(1)
    except Exception as ex:
        print(f'{PROG}: {ex}', file=sys.stderr)
        sys.exit(1)
