#!/usr/bin/env python3

"""Print lava worker status information."""

from __future__ import annotations

import argparse
import os
import sys
from collections import Counter
from collections.abc import Callable, Iterable
from dataclasses import dataclass
from datetime import timedelta
from itertools import chain
from typing import Any

import boto3
from tabulate import tabulate, tabulate_formats

from lava.config import config, config_load
from lava.lib.datetime import now_tz
from lava.version import __version__

__author__ = 'Murray Andrews'

# Name of this program (script basename without extension). Used as the argparse
# program name and as the prefix on error messages.
PROG = os.path.splitext(os.path.basename(sys.argv[0]))[0]

CW_METRIC_WINDOW = 15 * 60  # Only look back 15 mins for metrics (value is in seconds)

TABLE_FORMAT = 'fancy_grid'  # Default output table format ... or 'plain'

# This is what gets displayed. The first group is always shown. Subsequent
# groups are additional info, one more group for each -l option given.
# Tuples are (worker info dict key, heading, alignment, description). The
# alignment is passed to tabulate as the column alignment and the description
# is only used in the help text.
SHOW_INFO = (
    (
        ('_q_name', 'QUEUE', 'left', 'SQS queue name'),
        ('ApproximateNumberOfMessages', 'MSGS', 'decimal', 'Messages visible'),
        ('ApproximateNumberOfMessagesNotVisible', 'NVIS', 'decimal', 'Messages not visible'),
        ('_q_vis_timeout', 'VIS', 'right', 'Visibility timeout'),
        ('_q_retention', 'RET', 'right', 'Message retention period'),
    ),
    (
        ('_m_backlog_now', 'BCKNOW', 'decimal', 'Current backlog'),
        ('_m_backlog_avg', 'BCKAVG', 'decimal', 'Average worker backlog in the last 15 minutes'),
        ('_m_backlog_max', 'BCKMAX', 'decimal', 'Maximum worker backlog in the last 15 minutes'),
        ('_m_rundelay_avg', 'DELAVG', 'right', 'Average run delay in the last 15 minutes'),
        ('_m_rundelay_max', 'DELMAX', 'right', 'Maximum run delay in the last 15 minutes'),
    ),
    (
        ('_ec2_count', 'EC2', 'decimal', 'Number of running EC2 instances'),
        ('_ec2_type', 'EC2TYPE', 'left', 'EC2 instance type'),
    ),
)


# ------------------------------------------------------------------------------
def seconds_to_friendly(n: int | float) -> str:
    """
    Convert a number of seconds to human friendly format.

    The value is rounded to a whole number of seconds and broken down into
    days, hours, minutes and seconds. Components that are zero are omitted, so
    3661 becomes `1h 1m 1s` and 86400 becomes `1d`.

    :param n:   Number of seconds. A negative value is rendered as the
                equivalent positive duration with a leading minus sign.
    :return:    A string of space separated `<value><unit>` components using
                the units d, h, m and s, or `0s` if the rounded value is zero.
    """

    total = int(round(n, 0))
    if total == 0:
        # Every component would be suppressed, leaving an empty string that
        # cannot be told apart from "no data" in the output table.
        return '0s'

    # Work on the magnitude so that divmod() doesn't borrow from higher units.
    sign = '-' if total < 0 else ''
    remainder = abs(total)

    parts = []
    # Peel off seconds, then minutes, then hours. Whatever is left is days.
    for unit, size in (('s', 60), ('m', 60), ('h', 24)):
        remainder, value = divmod(remainder, size)
        parts.append((unit, value))
    parts.append(('d', remainder))

    # Largest unit first, skipping zero components.
    return sign + ' '.join(f'{value}{unit}' for unit, value in reversed(parts) if value)


# ------------------------------------------------------------------------------
@dataclass
class CwMetric:
    """
    Describe one Cloudwatch metric statistic to retrieve for a worker.

    :param label:   Key under which the retrieved value is stored in the worker
                    info dict. This is the field name that goes in SHOW_INFO.
    :param name:    Cloudwatch metric name.
    :param stat:    Cloudwatch statistic to request for the metric.
    :param period:  Period, in seconds, requested for the statistic.
    :param norm:    Callable applied to the raw metric value to normalise it
                    for display.
    """

    label: str
    name: str
    stat: str
    period: int
    norm: Callable


# The Cloudwatch metrics retrieved for each worker. Each label is the key used
# to store the value in the worker info dict, so it must match a field name in
# SHOW_INFO for the value to be displayed. The "now" backlog uses a 60 second
# period whereas the others use a period spanning the whole metric window.
CW_METRICS = (
    CwMetric('_m_backlog_now', 'WorkerBacklog', 'Maximum', 60, int),
    CwMetric('_m_backlog_avg', 'WorkerBacklog', 'Average', CW_METRIC_WINDOW, int),
    CwMetric('_m_backlog_max', 'WorkerBacklog', 'Maximum', CW_METRIC_WINDOW, int),
    CwMetric('_m_rundelay_avg', 'RunDelay', 'Average', CW_METRIC_WINDOW, seconds_to_friendly),
    CwMetric('_m_rundelay_max', 'RunDelay', 'Maximum', CW_METRIC_WINDOW, seconds_to_friendly),
)


# ------------------------------------------------------------------------------
def process_cli_args() -> argparse.Namespace:
    """
    Parse and validate the command line.

    The help epilog lists the output columns defined in SHOW_INFO and the
    table formats known to tabulate.

    :return:    The parsed args namespace. If the requested table format is
                not known to tabulate, the parser's error handler is invoked
                instead of returning.
    """

    def indented_table(rows: list) -> str:
        """Render sorted rows as a plain table with each line indented."""
        # Tabulate strips leading spaces so the indent is added afterwards.
        rendered = tabulate(sorted(rows), tablefmt='plain')
        return '\n'.join(f'  {line}' for line in rendered.splitlines())

    # Explainer table for the output columns: heading and description only.
    columns = [
        (heading, description)
        for _field, heading, _align, description in chain.from_iterable(SHOW_INFO)
    ]
    column_help = indented_table(columns)

    # Lay the table format names out in rows of a fixed width.
    formats_per_row = 6
    format_rows = [
        tabulate_formats[start : start + formats_per_row]
        for start in range(0, len(tabulate_formats), formats_per_row)
    ]
    format_help = indented_table(format_rows)

    parser = argparse.ArgumentParser(
        prog=PROG,
        description='Get status info about lava workers and their dispatch queues.',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=f'output columns:\n{column_help}\n\noutput formats:\n{format_help}',
    )

    all_queues_help = (
        'By default, only SQS queues with the "lava:function" tag set to'
        ' "worker.dispatch" are examined. This option caters for older'
        ' lava installations that may not have this tag setting by'
        ' examining all queues with names in the right form, some of'
        ' which may not be worker dispatch queues.'
    )
    parser.add_argument('-a', '--all-queues', action='store_true', help=all_queues_help)

    # Choices are not used for the format because they make a mess of the
    # help. The value is validated after parsing instead.
    format_option_help = (
        'Output table format (see below). The formats supported by tabulate'
        ' (https://pypi.org/project/tabulate/) can be used. The default is'
        f' {TABLE_FORMAT}.'
    )
    parser.add_argument(
        '-f', '--format', action='store', default=TABLE_FORMAT, help=format_option_help
    )

    # Counting starts at 1 so the value is the number of SHOW_INFO groups to show.
    parser.add_argument(
        '-l',
        action='count',
        default=1,
        help='Show more information. Repeat up to 2 times to get more details.',
    )

    parser.add_argument('-r', '--realm', required=True, action='store', help='Lava realm name.')

    worker_help = (
        'Lava worker name prefix. If not specified, report on all workers in'
        ' the realm (assumes lava standard queue naming conventions).'
    )
    parser.add_argument('-w', '--worker', action='store', help=worker_help)

    parser.add_argument('-v', '--version', action='version', version=__version__)

    parsed = parser.parse_args()

    if parsed.format not in tabulate_formats:
        parser.error(f'Table format must be one of:\n{format_help}')

    return parsed


# ------------------------------------------------------------------------------
def get_worker_queues(
    realm: str,
    worker: str | None = None,
    all_queues: bool = False,
    aws_session: boto3.Session | None = None,
) -> Iterable:
    """
    Find the SQS queues that serve lava workers in a realm.

    Candidate queues are those whose names start with `lava-<realm>-`, or
    `lava-<realm>-<worker>` when a worker is given.

    :param realm:       Lava realm.
    :param worker:      Lava worker name (or name prefix). If not specified,
                        all worker queues for the realm are included.
    :param all_queues:  If False (the default), only queues that have the
                        `lava:function` tag set to `worker.dispatch` are
                        returned. Older lava installations don't have this tag
                        so setting this to True returns every queue with a name
                        in the right form (some of which may not actually be
                        worker queues).
    :param aws_session: A boto3 Session(). One is created if not specified.

    :return:            An iterable of boto3 SQS Queue() objects.
    """

    session = aws_session or boto3.Session()
    sqs = session.resource('sqs')

    if worker:
        name_prefix = f'lava-{realm}-{worker}'
    else:
        name_prefix = f'lava-{realm}-'
    candidates = sqs.queues.filter(QueueNamePrefix=name_prefix)

    if all_queues:
        return candidates

    client = sqs.meta.client

    def is_dispatch_queue(queue) -> bool:
        """Check if the queue is tagged as a worker dispatch queue."""
        tags = client.list_queue_tags(QueueUrl=queue.url).get('Tags', {})
        return tags.get('lava:function') == 'worker.dispatch'

    # Lazy: the tags for a queue are only fetched when the caller reaches it.
    return filter(is_dispatch_queue, candidates)


# ------------------------------------------------------------------------------
def extract_worker_queue_info(queue) -> dict[str, Any]:
    """
    Build a worker info dict from an SQS queue.

    The result is a copy of the queue attributes with these derived entries
    added:

    - `_q_name`: queue name (the last component of the queue ARN)
    - `_q_vis_timeout`: visibility timeout in human friendly form
    - `_q_retention`: message retention period in human friendly form
    - `_lava_realm`, `_lava_worker`: second and third parts of the queue
      name when split on hyphens into `lava-<realm>-<worker>`.

    NOTE: The realm is taken to be the text between the first two hyphens so a
    realm name that itself contains a hyphen will not be split as intended.

    :param queue:       A boto3 SQS service Queue instance.
    :return:            A dict of SQS queue attributes plus derived entries.
    """

    info: dict = dict(queue.attributes)

    queue_name = info['QueueArn'].rsplit(':', 1)[-1]
    visibility_timeout = seconds_to_friendly(int(info['VisibilityTimeout']))
    retention = seconds_to_friendly(int(info['MessageRetentionPeriod']))
    # Only the first 2 hyphens are separators. The worker name keeps any others.
    _prefix, realm, worker = queue_name.split('-', 2)

    info.update(
        _q_name=queue_name,
        _q_vis_timeout=visibility_timeout,
        _q_retention=retention,
        _lava_realm=realm,
        _lava_worker=worker,
    )

    return info


# ------------------------------------------------------------------------------
def add_worker_metrics(workers_info: list[dict[str, Any]], aws_session: boto3.Session) -> None:
    """
    Augment the worker info with metric data.

    For every worker, each metric in CW_METRICS is requested from Cloudwatch
    over the last CW_METRIC_WINDOW seconds and the most recent value, passed
    through the metric's normaliser, is stored in the worker info dict under
    the metric's label. The label is set to None if there is no data.

    :param workers_info:    A list of worker info dicts that get augmented with
                            extra info. Each must contain `_lava_realm` and
                            `_lava_worker`.
    :param aws_session:     A boto3 Session.
    """

    if not workers_info:
        return

    # Cloudwatch rejects a GetMetricData request with more queries than this.
    max_queries_per_request = 500

    cloudwatch = aws_session.client('cloudwatch')
    end_time = now_tz().replace(second=0, microsecond=0)
    start_time = end_time - timedelta(seconds=CW_METRIC_WINDOW)
    namespace = config('CW_NAMESPACE')

    metric_data_queries = []
    workers_metrics_map: dict = {}
    for w_info in workers_info:
        dimensions = [
            {'Name': 'Realm', 'Value': w_info['_lava_realm']},
            {'Name': 'Worker', 'Value': w_info['_lava_worker']},
        ]
        for metric in CW_METRICS:
            query_id = f'w_{len(metric_data_queries)}'
            metric_data_queries.append(
                {
                    'Id': query_id,
                    'Label': metric.name,
                    'MetricStat': {
                        'Metric': {
                            'Namespace': namespace,
                            'MetricName': metric.name,
                            'Dimensions': dimensions,
                        },
                        'Period': metric.period,
                        'Stat': metric.stat,
                    },
                }
            )
            # We need to be able to map our results back to specific worker
            # info dicts using the Id field.
            workers_metrics_map[query_id] = (metric, w_info)
            # Make sure the field exists even if no data comes back for it.
            w_info[metric.label] = None

    paginator = cloudwatch.get_paginator('get_metric_data')
    for offset in range(0, len(metric_data_queries), max_queries_per_request):
        for page in paginator.paginate(
            MetricDataQueries=metric_data_queries[offset : offset + max_queries_per_request],
            StartTime=start_time,
            EndTime=end_time,
            ScanBy='TimestampDescending',
        ):
            for metric_data in page['MetricDataResults']:
                metric, w_info = workers_metrics_map[metric_data['Id']]
                values = metric_data['Values']
                # Results are newest first, so the first value seen for an Id
                # is the one we want. The same Id can turn up again on a later
                # page with older (or no) values which must not replace it.
                if values and w_info[metric.label] is None:
                    w_info[metric.label] = metric.norm(values[0])


# ------------------------------------------------------------------------------
def tag_value(tag_name: str, tags: list[dict[str, str]]) -> str | None:
    """
    Look up a tag by name in an AWS style list of tags.

    :param tag_name:    The tag name.
    :param tags:        A list of tag dicts, each with `Key` and `Value`
                        entries.
    :return:            The value of the first tag with the given name, or
                        None if not present.
    """

    return next((tag['Value'] for tag in tags if tag['Key'] == tag_name), None)


# ------------------------------------------------------------------------------
def add_worker_ec2info(workers_info: list[dict[str, Any]], aws_session: boto3.Session) -> None:
    """
    Augment the worker info with EC2 instance data.

    Workers are assumed to have the same name as the queue they service, so
    running instances are matched to workers by comparing the instance `Name`
    tag with the queue name.

    Each worker info dict gains `_ec2_count` (the number of running instances)
    and `_ec2_type` (the instance type of one of them). Both are None if the
    worker has no running instances.

    :param workers_info:    A list of worker info dicts that get augmented with
                            extra info. Each must contain `_q_name`.
    :param aws_session:     A boto3 Session.
    """

    if not workers_info:
        return
    ec2 = aws_session.resource('ec2')

    worker_instances = ec2.instances.filter(
        Filters=[
            {
                'Name': 'instance-state-name',
                'Values': ['running'],
            },
            {
                'Name': 'tag:Name',
                'Values': [w['_q_name'] for w in workers_info],
            },
        ]
    )

    # The collection is lazy and queries EC2 afresh each time it is iterated,
    # so gather everything in a single pass. This avoids repeating the API
    # calls and keeps the counts and types consistent with each other.
    instance_counts: Counter = Counter()
    instance_types: dict = {}
    for instance in worker_instances:
        name = tag_value('Name', instance.tags)
        instance_counts[name] += 1
        instance_types[name] = instance.instance_type

    for w in workers_info:
        # Use get() rather than indexing so a worker with no instances gets
        # None instead of the Counter default of 0.
        w['_ec2_count'] = instance_counts.get(w['_q_name'])
        w['_ec2_type'] = instance_types.get(w['_q_name'])


# ------------------------------------------------------------------------------
def main() -> int:
    """
    Gather the worker status information and print it as a table.

    Queue information is always shown. Metric data and EC2 instance data are
    added according to the number of -l options given. Nothing is printed if
    no worker queues are found.

    :return:    Exit status, which is always 0.
    """

    args = process_cli_args()
    aws_session = boto3.Session()

    config_load(args.realm, aws_session=aws_session)

    # Collect all of the queues before extracting anything from them.
    queues = list(
        get_worker_queues(
            args.realm, worker=args.worker, all_queues=args.all_queues, aws_session=aws_session
        )
    )
    workers_info = [extract_worker_queue_info(queue) for queue in queues]
    if not workers_info:
        return 0

    detail_level = args.l
    if detail_level > 1:
        add_worker_metrics(workers_info, aws_session)
    if detail_level > 2:
        add_worker_ec2info(workers_info, aws_session)

    # One SHOW_INFO group per detail level, flattened into a list of columns.
    columns = list(chain.from_iterable(SHOW_INFO[:detail_level]))
    rows = [[worker[field] for field, *_ in columns] for worker in workers_info]
    headers = [heading for _, heading, *_ in columns]
    alignments = [alignment for _, _, alignment, *_ in columns]

    table = tabulate(
        rows,
        headers=headers,
        colalign=alignments,
        tablefmt=args.format,
        floatfmt='.1f',
    )
    print(table)

    return 0


# ------------------------------------------------------------------------------
if __name__ == '__main__':
    # Uncomment for debugging so that exceptions propagate with a traceback.
    # sys.exit(main())  # noqa: ERA001
    try:
        # Use sys.exit() rather than exit(). The latter is a convenience
        # installed by the site module and is not guaranteed to be available.
        sys.exit(main())
    except Exception as ex:
        print(f'{PROG}: {ex}', file=sys.stderr)
        sys.exit(1)
