#!/usr/bin/env python3

"""Utility to query map the activity of specified lava jobs."""

from __future__ import annotations

import argparse
import csv
import json
import os
import sys
from collections.abc import Iterator, Sequence
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Any

import boto3
from dateutil.parser import isoparse, parse

from lava.lib.aws import dynamo_unmarshall_item

__author__ = 'Murray Andrews'

PROG = os.path.splitext(os.path.basename(sys.argv[0]))[0]

UTC = timezone.utc

END_OF_RUN_STATUS = ('complete', 'failed', 'retrying')

# ISO 8601 format to match tu_event fields in the DynamoDB events table
TS_FMT_EVENT = '%Y-%m-%dT%H:%M:%S'
# Something excel will swallow
TS_FMT_EXCEL = '%Y-%m-%d %H:%M:%S'
# Chunk job activity into intervals this size.
INTERVAL_MINUTES = 10
INTERESTING_EVENT_FIELDS = ('job_id', 'run_id', 'status', 'worker', 'ts_dispatch', 'tu_event')

# Type aliases
LavaRecord = dict[str, Any]


# ------------------------------------------------------------------------------
def process_cli_args() -> argparse.Namespace:
    """
    Process the command line arguments.

    :return:    The args namespace.
    """

    argp = argparse.ArgumentParser(prog=PROG, description='Lava event log query utility')

    argp.add_argument('--profile', action='store', help='As for AWS CLI.')
    argp.add_argument(
        '-q',
        '--quiet',
        action='store_true',
        help='Don\'t print progress messages on stderr.',
    )
    argp.add_argument('-r', '--realm', required=True, action='store', help='Lava realm name.')

    queryp = argp.add_argument_group(title='query arguments')
    queryp.add_argument(
        '-s',
        '--start',
        dest='start_dtz',
        action=StoreTimestamp,
        help=(
            'Start datetime. Preferred format is ISO 8601. If a timezone is not'
            ' specified, UTC is assumed. When using --load, the default is the'
            ' value from the source file. Otherwise, the default is the most'
            ' recent midnight (UTC).'
        ),
    )
    queryp.add_argument(
        '-e',
        '--end',
        action='store',
        dest='end_dtz',
        help=(
            'End datetime. Preferred format is ISO 8601. If a timezone is not'
            ' specified, UTC is assumed. When using --load, the default is the'
            ' value from the source file. Otherwise the default is 24 hours after'
            ' the start time.'
        ),
    )
    # TODO: Get rid of this or do it properly
    queryp.add_argument(
        '--status',
        action='store',
        help='Only include events with the given status.',
    )

    loadp = argp.add_argument_group('dump / load arguments')
    loadp.add_argument(
        '--dump',
        metavar='FILE',
        action='store',
        help=(
            'Dump the raw data into the specified file in JSON format. The'
            ' format is suitable for loading using the --load option. If both'
            ' --load and --store are used, they must be different files.'
        ),
    )
    loadp.add_argument(
        '--load',
        metavar='FILE',
        action='store',
        help=(
            'Load the raw data from the specified file instead of reading it'
            ' from DynamoDB. The file will have been produced by a previous'
            ' run using the --dump option. This allows a set of data to be'
            ' reprocessed without re-extracting the same data.'
        ),
    )

    outp = argp.add_argument_group('output arguments')
    outp.add_argument(
        '-i',
        '--interval',
        metavar='MINUTES',
        default=INTERVAL_MINUTES,
        type=int,
        help=(
            'Aggregate job activity into intervals of the specified duration'
            f' (minutes). Stick to divisors or multiples of 60.'
            f' Default is {INTERVAL_MINUTES}.'
        ),
    )

    # ----------------------------------------
    argp.add_argument(
        'job_id',
        metavar='job-id',
        nargs='*',
        help=(
            'Retrieve records for the specified job-id. Required unless --load'
            ' is used. If used with --load, this acts as a further filter on'
            ' event records loaded from the dump file.'
        ),
    )

    args = argp.parse_args()

    if args.dump and args.load and Path(args.dump).samefile(args.load):
        argp.error('--dump and --load cannot point to the same file')

    return args


# ------------------------------------------------------------------------------
class StoreTimestamp(argparse.Action):
    """Handle timestamp arguments on argparse."""

    def __call__(self, parser, namespace, values, option_string=None):
        """
        Convert timestamp strings to timezeone aware datetimes.

        If the source value does not include a timezone, UTC is assumed.
        """

        try:
            ts = parse(values)
        except Exception:
            raise argparse.ArgumentTypeError(f'{values}: bad timestamp')

        if not ts.tzinfo:
            ts.replace(tzinfo=UTC)

        setattr(namespace, self.dest, ts)


# ------------------------------------------------------------------------------
class TimeSpan:
    """
    Model a time span.

    The timespan is closed at the bottom and open at the top. The top and bottom
    will be swapped if necessary to keep start <= end.

    :param start_dt:    Start time as an ISO 8601 timestamp or datetime.
                        If no timezone is specified, UTC is assumed.
    :param end_dt:      End time as an ISO 8601 timestamp or datetime.
                        If no timezone is specified, UTC is assumed.
    """

    # --------------------------------------------------------------------------
    def __init__(self, start_dt: datetime | str, end_dt: datetime | str):
        """Create a timespan."""

        self.start_dtz = start_dt if isinstance(start_dt, datetime) else isoparse(start_dt)
        if not self.start_dtz.tzinfo:
            self.start_dtz = self.start_dtz.replace(tzinfo=UTC)

        self.end_dtz = end_dt if isinstance(end_dt, datetime) else isoparse(end_dt)
        if not self.end_dtz.tzinfo:
            self.end_dtz = self.end_dtz.replace(tzinfo=UTC)

        self.start_dtz, self.end_dtz = min(self.start_dtz, self.end_dtz), max(
            self.start_dtz, self.end_dtz
        )

    # --------------------------------------------------------------------------
    def __len__(self) -> int:
        """Get length of time span in seconds."""

        return round((self.end_dtz - self.start_dtz).total_seconds())

    # --------------------------------------------------------------------------
    def __str__(self) -> str:
        """Convert timespan to string."""
        return f'{self.start_dtz.isoformat()} --> {self.end_dtz.isoformat()}'

    # --------------------------------------------------------------------------
    def __repr__(self):
        """Convert to repr."""
        return (
            f"{self.__class__.__name__}"
            f"('{self.start_dtz.isoformat()}', '{self.end_dtz.isoformat()}')"
        )

    # --------------------------------------------------------------------------
    def __contains__(self, dt) -> bool:
        """Check if a given datetime is inside the time span."""

        return self.start_dtz <= dt < self.end_dtz

    # --------------------------------------------------------------------------
    def __le__(self, other: TimeSpan) -> bool:
        """Check if this timespan is contained within another."""

        return all(other.start_dtz <= t <= other.end_dtz for t in (self.start_dtz, self.end_dtz))

    # --------------------------------------------------------------------------
    def asdict(self) -> dict[str, datetime]:
        """Convert to a dict."""

        return {
            'start_dt': self.start_dtz,
            'end_dt': self.end_dtz,
        }

    # --------------------------------------------------------------------------
    def intersection(self, t: TimeSpan) -> TimeSpan | None:
        """
        Determine the intersection with the given timespan.

        :param t:   Timespan.
        :return:    A timespan covering the intersection or None if no overlap.
        """

        if t.start_dtz >= self.end_dtz or self.start_dtz >= t.end_dtz:
            return None

        # There is some overlap
        return self.__class__(max(self.start_dtz, t.start_dtz), min(self.end_dtz, t.end_dtz))

    # --------------------------------------------------------------------------
    def segments(self, interval: timedelta) -> Iterator[TimeSpan]:
        """
        Divide a timespan into segments.

        :param interval:        A timedelta for segment size.
        :return:                A generator of TimeSpan objects.
        """

        t = self.start_dtz
        while t < self.end_dtz:
            yield self.__class__(t, min(t + interval, self.end_dtz))
            t += interval


# ------------------------------------------------------------------------------
def json_default(obj: Any) -> Any:
    """
    Serialise non-standard objects for json.dumps().

    :param obj:             An object.
    :return:                A serialisable version. For datetime objects we just
                            convert them to a string that strptime() could handle.

    :raise TypeError:       If obj cannot be serialised.
    """

    if isinstance(obj, TimeSpan):
        return obj.asdict()

    if isinstance(obj, datetime):
        return obj.isoformat()

    try:
        return str(obj)
    except Exception:
        raise TypeError(f'Cannot serialize {type(obj)}')


# ------------------------------------------------------------------------------
def get_event_records_from_dynamo(
    realm: str,
    job_ids: Sequence[str],
    timespan: TimeSpan,
    status: str | None = None,
    aws_session: boto3.Session | None = None,
    quiet: bool = False,
) -> Iterator[LavaRecord]:
    """
    Get event records from the lava events table.

    :param realm:           Lava realm.
    :param job_ids:         A sequence of job_ids.
    :param timespan:        The time range for the extraction based on tu_event.
    :param status:          Job status. If not specified, get all.
    :param aws_session:     A boto3 Session().
    :param quiet:           If True, don't print progress messages on stderr.
    :return:                An iterator over event records.
    """

    dynamo = (aws_session or boto3.Session()).client('dynamodb')
    start_time = timespan.start_dtz.replace(microsecond=0).astimezone(UTC).strftime(TS_FMT_EVENT)
    end_time = timespan.end_dtz.replace(microsecond=0).astimezone(UTC).strftime(TS_FMT_EVENT)

    query_args = {
        'TableName': f'lava.{realm}.events',
        'IndexName': 'job_id-tu_event-index',
        'KeyConditionExpression': '#job_id = :job_id AND #tu_event BETWEEN :tu_start AND :tu_end',
        'ExpressionAttributeNames': {'#job_id': 'job_id', '#tu_event': 'tu_event'},
        'ExpressionAttributeValues': {
            # job_id is added in the loop below
            ':tu_start': {'S': start_time},
            ':tu_end': {'S': end_time},
        },
    }

    if status:
        query_args['FilterExpression'] = '#status = :status'
        query_args['ExpressionAttributeNames']['#status'] = 'status'
        query_args['ExpressionAttributeValues'][':status'] = {'S': status}

    for job_id in job_ids:
        if not quiet:
            print(f'Getting data for {job_id} ... ', file=sys.stderr, end='')
        query_args['ExpressionAttributeValues'][':job_id'] = {'S': job_id}
        paginator = dynamo.get_paginator('query')
        response_iterator = paginator.paginate(**query_args)
        run_count = 0
        for response in response_iterator:
            for item in response['Items']:
                run_count += 1
                yield dynamo_unmarshall_item(item)
        if not quiet:
            print(f'{run_count} completed job run(s) found', file=sys.stderr)


# ------------------------------------------------------------------------------
def dump_events(
    filename: str,
    realm: str,
    events: Sequence[LavaRecord],
    job_ids: Sequence[str],
    timespan: TimeSpan,
) -> None:
    """Dump the summarised event records to a file."""

    with open(filename, 'w') as fp:
        json.dump(
            {'realm': realm, 'timespan': timespan, 'jobs': job_ids, 'events': events},
            fp,
            indent=4,
            default=json_default,
        )


# ------------------------------------------------------------------------------
def load_events(
    filename: str,
    realm: str,
    job_ids: Sequence[str],
    start_dtz: datetime,
    end_dtz: datetime,
) -> tuple[Sequence[str], TimeSpan, Sequence[LavaRecord]]:
    """
    Load summarised events from a file.

    :param filename:    The name of the file.
    :param realm:       The lava realm.
    :param job_ids:     A sequence of job IDs for which to load event data. It
                        is an error if any requested job ID wasn't present in
                        the original dump that produced the file. If not specified,
                        all of the job IDs from the original request are included.
    :param start_dtz:   Timezone aware datetime for the start of the timespan of
                        interest. If None, use what's in the dump file.
    :param end_dtz:     Timezone aware datetime for the end of the timespan of
                        interest. If None, use what's in the dump file.
    :return:            A tuple (job IDs covered in the request, timespan, event records).
    """

    with open(filename) as fp:
        data = json.load(fp)

    try:
        dumped_realm = data['realm']
        dumped_jobs = set(data['jobs'])
        dumped_events = data['events']
        _ = data['timespan']
    except KeyError as e:
        raise Exception(f'{filename}: Format error -- no {e} key')

    if dumped_realm != realm:
        raise Exception(f'{filename}: Dumped realm {dumped_realm} doesn\'t match {realm}')

    # Reconstruct timespan of original dump and make sure new reauest is inside that.
    try:
        dumped_timespan = TimeSpan(**data['timespan'])
    except Exception as e:
        raise Exception(f'{filename}: {e}')

    requested_timespan = TimeSpan(
        start_dtz or dumped_timespan.start_dtz,
        end_dtz or dumped_timespan.end_dtz,
    )
    if not requested_timespan <= dumped_timespan:
        raise Exception(
            f'{filename}: Timespan {requested_timespan} is not'
            f' within dump timespan {dumped_timespan}'
        )

    if job_ids:
        requested_job_set = set(job_ids)
        if missing_jobs := requested_job_set - dumped_jobs:
            raise Exception(f'{filename}: These jobs are missing: {", ".join(missing_jobs)}')
    else:
        requested_job_set = None
        job_ids = dumped_jobs

    # Reconstruct event records and make sure only requested jobs are included
    selected_events = []
    for ev in dumped_events:
        if requested_job_set and ev['job_id'] not in requested_job_set:
            continue
        ev['tu_event'] = isoparse(ev['tu_event']).replace(tzinfo=UTC)
        if ev['tu_event'] not in requested_timespan:
            continue
        ev['runtimes'] = [TimeSpan(**r) for r in ev['runtimes']]
        selected_events.append(ev)

    return job_ids, requested_timespan, selected_events


# ------------------------------------------------------------------------------
def process_event_record(event: LavaRecord) -> LavaRecord:
    """Process the event record to extract information of interest."""

    result = {k: event[k] for k in INTERESTING_EVENT_FIELDS}
    result['runtimes'] = []

    # Individual run cycles are bracketted by ('running', end-status) pairs where
    # end-status is `complete`, `failed` or `retrying`).
    start_time = None
    for job_run in event['log']:
        if job_run['status'] == 'running':
            start_time = job_run['tu_event']
            continue
        if job_run['status'] in END_OF_RUN_STATUS:
            if not start_time:
                raise Exception(f'Bad event record: {event}')
            result['runtimes'].append(TimeSpan(start_time, job_run['tu_event']))
            start_time = None

    # If start_time is not None here we have an incomplete run -- ignore it for now
    return result


# ---------------------------------------------------------------------------------------
def main() -> int:
    """
    Show time.

    :return:    status
    """

    args = process_cli_args()
    aws_session = boto3.Session(profile_name=args.profile)

    if args.load:
        _job_ids, query_range, events = load_events(
            args.load,
            args.realm,
            args.job_id,
            args.start_dtz,
            args.end_dtz,
        )
    else:
        start_dtz_default = datetime.now(timezone.utc).replace(
            hour=0, minute=0, second=0, microsecond=0
        )
        end_dtz_default = (args.start_dtz or start_dtz_default) + timedelta(hours=24)
        query_range = TimeSpan(args.start_dtz or start_dtz_default, args.end_dtz or end_dtz_default)
        events = [
            process_event_record(rec)
            for rec in get_event_records_from_dynamo(
                args.realm, args.job_id, query_range, args.status, aws_session, args.quiet
            )
        ]

    if args.dump:
        dump_events(args.dump, args.realm, events, args.job_id, query_range)
        if not args.quiet:
            print(f'Raw data dumped to {args.dump}', file=sys.stderr)

    # ----------------------------------------
    # Each job may have had multiple runs both in separate events and in retries.
    # We want to slice all of those runs across defined timeslots to build two
    # workload maps over timeslots: one by job_id and one by worker. This gives
    # a picture of where all the load is.

    job_time_map: dict[str, list[int]] = {}
    worker_time_map: dict[str, list[int]] = {}
    job_total_times: dict[str, int] = {}
    worker_total_times: dict[str, int] = {}
    slots = list(query_range.segments(timedelta(minutes=args.interval)))

    for event in events:
        job_id = event['job_id']
        worker = event['worker']
        job_time_map.setdefault(job_id, [0] * len(slots))
        worker_time_map.setdefault(worker, [0] * len(slots))
        job_total_times.setdefault(job_id, 0)
        worker_total_times.setdefault(worker, 0)

        # Workout the time contribution of each job in each time slot.
        for n, slot in enumerate(slots):
            for run_time in event['runtimes']:
                if overlap := slot.intersection(run_time):
                    overlap_secs = len(overlap)
                    job_time_map[job_id][n] += overlap_secs
                    worker_time_map[worker][n] += overlap_secs
                    job_total_times[job_id] += overlap_secs
                    worker_total_times[worker] += overlap_secs

    writer = csv.writer(sys.stdout)
    writer.writerow(['Job ID'] + [slot.start_dtz.strftime(TS_FMT_EXCEL) for slot in slots])

    for job_id in sorted(job_total_times, key=lambda v: job_total_times[v], reverse=True):
        if job_total_times[job_id]:
            writer.writerow([job_id, *job_time_map[job_id]])

    writer.writerow([])
    writer.writerow(['Worker'] + [slot.start_dtz.strftime(TS_FMT_EXCEL) for slot in slots])
    for worker in sorted(worker_total_times, key=lambda v: worker_total_times[v], reverse=True):
        # for worker, jobtimes in worker_time_map.items():
        writer.writerow([worker, *worker_time_map[worker]])

    return 0


# ------------------------------------------------------------------------------
if __name__ == '__main__':
    # Uncomment for debugging
    # exit(main())  # noqa: ERA001
    try:
        exit(main())
    except Exception as ex:
        print(f'{PROG}: {ex}', file=sys.stderr)
        exit(1)
    except KeyboardInterrupt:
        print('Interrupt', file=sys.stderr)
        exit(2)
