#!/usr/bin/env python3

"""
Find and list lava workers.

Emulates style of ps(1) with some lava specific columns.

"""

from __future__ import annotations

import argparse
import os
import sys
import time
from datetime import datetime, timedelta
from itertools import chain
from typing import Any

import psutil
from tabulate import tabulate

from lava.lib.argparse import ArgparserExitError, ArgparserNoExit
from lava.version import __version__

__author__ = 'Murray Andrews'

# Program name used in usage output and error messages: the basename of the
# invoked script with any extension (e.g. ".py") removed.
PROG = os.path.splitext(os.path.basename(sys.argv[0]))[0]

# Base info requested from psutil for processes. Note that psutil fetches
# these for EVERY process on the system (lava worker filtering happens
# afterwards), so each extra attribute has a cost. Comment out ones not needed.
PS_INFO = [
    'cmdline',
    # 'connections',
    'cpu_percent',
    'cpu_times',
    'create_time',
    # 'cwd',
    # 'environ',
    # 'exe',  # Not used for display.
    'gids',
    # 'memory_full_info',
    'memory_info',
    'memory_percent',
    # 'name',
    # 'nice',
    # 'num_ctx_switches',
    'num_fds',
    'num_threads',
    # 'open_files',
    'pid',
    # 'ppid',
    'status',
    # 'terminal',
    # 'threads',  # Not used for display; num_threads gives the thread count.
    'uids',
    'username',
]

# This is what gets displayed. Subsequent entries are additional info with -l option(s).
# The first group is always shown; each -l adds the next group.
# Tuples are (proc_info field name, heading, alignment). Field names starting
# with an underscore are derived in get_lava_workers() rather than taken
# directly from psutil. Alignments are passed to tabulate's colalign.
SHOW_INFO = (
    (
        ('username', 'USER', 'left'),
        ('pid', 'PID', 'decimal'),
        ('_lava_realm', 'REALM', 'left'),
        ('_lava_worker', 'WORKER', 'left'),
    ),
    (
        ('cpu_percent', '%CPU', 'decimal'),
        ('memory_percent', '%MEM', 'decimal'),
        ('_rss_mb', 'RSS', 'decimal'),  # Actually KiB (rss // 1024), despite the key name.
        ('status', 'STATUS', 'left'),
        ('_create_time', 'STARTED', 'right'),
        ('_cpu_time', 'TIME', 'decimal'),
    ),
    (
        ('num_threads', 'THREADS', 'decimal'),
        ('num_fds', 'FILES', 'decimal'),
        ('_euid', 'UID', 'decimal'),
        ('_egid', 'GID', 'decimal'),
        ('_cmdline', 'COMMAND', 'left'),
    ),
)


# ------------------------------------------------------------------------------
def process_cli_args() -> argparse.Namespace:
    """
    Parse this utility's own command line.

    Recognised options are a repeatable -l (detail level) and -v/--version.

    :return:    The parsed namespace. Its ``l`` attribute is the detail level,
                starting at 1 and incremented once per -l given.
    """

    parser = argparse.ArgumentParser(
        prog=PROG,
        description='List lava worker processes.',
    )

    # (option strings, add_argument keyword arguments), in registration order.
    option_specs = (
        (
            ('-l',),
            dict(
                action='count',
                default=1,
                help='Show more information. Repeat up to 2 times to get more details.',
            ),
        ),
        (
            ('-v', '--version'),
            dict(action='version', version=__version__),
        ),
    )
    for flags, options in option_specs:
        parser.add_argument(*flags, **options)

    namespace = parser.parse_args()
    return namespace


# ------------------------------------------------------------------------------
def process_worker_args(argv: list[str]) -> argparse.Namespace:
    """
    Extract the realm and worker name from a lava worker's arguments.

    Only -r/--realm and -w/--worker are recognised. Anything else is
    tolerated and discarded.

    :param argv:    The worker's arguments, i.e. its command line without the
                    interpreter and script name.
    :return:        A namespace with ``realm`` and ``worker`` attributes.
    """

    parser = ArgparserNoExit()

    for short_flag, long_flag in (('-r', '--realm'), ('-w', '--worker')):
        parser.add_argument(short_flag, long_flag, action='store')

    # parse_known_args() returns (namespace, leftovers); leftovers are ignored.
    known, _unrecognised = parser.parse_known_args(argv)
    return known


# ------------------------------------------------------------------------------
def seconds_to_mins_seconds(seconds: float) -> str:
    """
    Convert seconds to a string "MM:SS.ss".

    The value is rounded to the nearest hundredth of a second before being
    split into minutes and seconds. As a result, e.g. 59.996 becomes "01:00.00"
    rather than the invalid "00:60.00". Minutes are zero-padded to at least two
    digits but not capped, so long durations get more digits
    (e.g. 6000 -> "100:00.00").

    :param seconds:     A non-negative number of seconds.
    :return:            The formatted duration.
    """

    # Work in integer hundredths of a second so that rounding can carry into
    # the minutes instead of producing a seconds field of "60.00".
    total_centis = round(seconds * 100)
    mins, centis = divmod(total_centis, 6000)
    secs, hundredths = divmod(centis, 100)
    return f'{mins:02}:{secs:02}.{hundredths:02}'


# ------------------------------------------------------------------------------
def ps_time_format(dt: datetime, now: datetime | None = None) -> str:
    """
    Format a datetime using rules similar to ps(1).

    - Less than 24 hours before ``now``: unpadded 12-hour clock time with a
      lowercase am/pm suffix, right-justified to 7 characters (" 9:05am").
    - Less than 7 days before ``now``: abbreviated weekday, zero-padded hour
      and am/pm ("Mon09am").
    - Otherwise: day of month, abbreviated month and 2-digit year ("05Jun23").

    :param dt:      The datetime of interest. Must be naive (like
                    datetime.now()) since the two are subtracted.
    :param now:     Reference "current" time. Defaults to datetime.now().
    :return:        See ps(1).
    """

    if now is None:
        now = datetime.now()
    age = now - dt
    am_pm = dt.strftime('%p').lower()

    if age < timedelta(hours=24):
        # '%-I' (unpadded hour) is a glibc/BSD strftime extension that is not
        # portable, so strip the zero from '%I' instead. '%I' is always 01-12,
        # so this never yields an empty string.
        hour = dt.strftime('%I').lstrip('0')
        return f'{hour}:{dt.strftime("%M")}{am_pm}'.rjust(7)

    if age < timedelta(days=7):
        return dt.strftime('%a%I') + am_pm

    return dt.strftime('%d%b%y')


# ------------------------------------------------------------------------------
def get_lava_workers() -> list[dict[str, Any]]:
    """
    Get process information for lava worker processes.

    A process is treated as a lava worker when its command line is a python
    interpreter running a script named ``lava-worker`` or ``lava-worker.py``
    and both a realm (-r/--realm) and a worker (-w/--worker) are given.

    :return:    A list of dicts, one per worker. Each holds the psutil fields
                named in PS_INFO plus derived display fields whose names start
                with an underscore. A derived field is None when the psutil
                value it depends on could not be read (e.g. access denied).
    """

    worker_info = []
    # Single reference instant so all workers' %CPU are computed consistently.
    now_ts = time.time()

    for proc in psutil.process_iter(attrs=PS_INFO):
        info = proc.info
        cmd = info['cmdline']
        if not _is_lava_worker_cmdline(cmd):
            continue

        try:
            worker_args = process_worker_args(cmd[2:])
        except ArgparserExitError:
            continue
        if not worker_args.realm or not worker_args.worker:
            continue

        # OK looks like a real worker daemon.
        info['_lava_realm'] = worker_args.realm
        info['_lava_worker'] = worker_args.worker
        _add_derived_fields(info, now_ts)
        worker_info.append(info)

    return worker_info


def _is_lava_worker_cmdline(cmd: list[str] | None) -> bool:
    """Return True if cmd looks like ``<python> [.../]lava-worker[.py] ...``."""

    # Need at least the interpreter and the script name.
    if not cmd or len(cmd) < 2:
        return False
    return 'python' in cmd[0].lower() and os.path.basename(cmd[1]) in (
        'lava-worker',
        'lava-worker.py',
    )


def _add_derived_fields(info: dict[str, Any], now_ts: float) -> None:
    """Add the underscore-prefixed display fields to a worker's psutil info dict."""

    # psutil reports unreadable attributes as None, so guard each one rather
    # than letting one inaccessible process abort the whole listing.
    uids = info['uids']
    gids = info['gids']
    mem = info['memory_info']
    cpu_times = info['cpu_times']
    create_time = info['create_time']

    info['_euid'] = None if uids is None else uids.effective
    info['_egid'] = None if gids is None else gids.effective
    # NB: despite the key name this is KiB (like ps(1) RSS), not MB.
    info['_rss_mb'] = None if mem is None else mem.rss // 1024
    info['_create_time'] = (
        None if create_time is None else ps_time_format(datetime.fromtimestamp(create_time))
    )
    info['_cmdline'] = ' '.join(info['cmdline'][1:])

    if cpu_times is None:
        info['_cpu_time'] = None
        info['cpu_percent'] = None
        return

    cpu_secs = cpu_times.user + cpu_times.system
    info['_cpu_time'] = seconds_to_mins_seconds(cpu_secs)

    # psutil's cpu_percent() measures usage since its previous call for the
    # process, so the first call (the one process_iter() made) always returns a
    # meaningless 0.0. Replace it with CPU time / elapsed wall time since start,
    # which is how ps(1) on Linux computes %CPU.
    if create_time is None:
        info['cpu_percent'] = None
        return
    elapsed = now_ts - create_time
    info['cpu_percent'] = 100.0 * cpu_secs / elapsed if elapsed > 0 else 0.0


# ------------------------------------------------------------------------------
def main() -> int:
    """
    Print a plain-text table of running lava workers.

    The set of columns grows with each -l given on the command line.

    :return:    Exit status. Always 0, including when no workers are found
                (in which case nothing is printed).
    """

    cli_args = process_cli_args()
    workers = get_lava_workers()

    if workers:
        # Flatten the first cli_args.l column groups into one column list.
        columns = tuple(chain.from_iterable(SHOW_INFO[: cli_args.l]))
        rows = [[worker[name] for name, _heading, _align in columns] for worker in workers]
        print(
            tabulate(
                rows,
                headers=[heading for _name, heading, _align in columns],
                colalign=[align for _name, _heading, align in columns],
                tablefmt='plain',
                floatfmt='.1f',
            )
        )

    return 0


# ------------------------------------------------------------------------------
if __name__ == '__main__':
    # Uncomment for debugging (lets exceptions propagate with a traceback).
    # sys.exit(main())  # noqa: ERA001
    try:
        # sys.exit() rather than the site-provided exit(), which is meant for
        # the interactive interpreter and is missing under `python -S`.
        sys.exit(main())
    except Exception as ex:
        # Report unexpected errors as a one-line message, not a traceback.
        print(f'{PROG}: {ex}', file=sys.stderr)
        sys.exit(1)
