#!/usr/bin/env python3
#
# SPDX-License-Identifier: MIT
#
# Copyright Red Hat
# Author: David Gibson <david@gibson.dropbear.id.au>

"""Tool for working with exeter test programs

Command Line Structure:
    exetool [options] subcommand [-- further_arguments...]

The command line is divided into several parts separated by '--':

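First, exetool options, the subcommand and any subcommand options:
   exetool [--version] [--help] subcommand [subcommand options]

Next comes the command that runs the exeter test program.  Some
subcommands accept another '--' followed by extra arguments, such as
specific test ids.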

Examples:
    exetool --version
    exetool list -- python3 my_tests.py
    exetool list -- ./compiled_test
    exetool list -- sh test_script.sh
    exetool metadata -- python3 my_tests.py -- test_id

"""

import argparse
import json
import re
import shlex
import subprocess
import sys
from typing import List, Optional, Tuple

# Version of this tool, reported by --version
EXETOOL_VERSION = "0.5.1"
# Exact line a test program must print in response to --exeter
EXETER_PROTOCOL_VERSION = "exeter test protocol 0.4.1"

# Exit status meaning "test skipped" (also given to Avocado as its skip
# exit code, and checked for in generated BATS scripts)
SKIP_CODE = 77
# Exit status exetool uses for its own errors
HARD_ERROR = 99

# Identifier validation regex (test IDs and metadata keys) based on
# PROTOCOL.md
# Must contain only alphanumeric, dots, semicolons, underscores
# Must be non-empty
# NOTE: with match()/search(), '$' also matches just before a single
# trailing newline; fullmatch() gives a strict whole-string check.
TEST_ID_PATTERN = re.compile(r'^[a-zA-Z0-9.;_]+$')

# BATS test name sanitisation regex: matches the characters that are
# replaced with '_' in BATS test names - backtick, '$', braces, square
# brackets, backslash, double and single quotes, newline, carriage
# return and tab
BATS_UNSAFE_PATTERN = re.compile(r'[`${}[\]\\"\'\n\r\t]')


def split_group(argv: List[str]) -> Tuple[List[str], List[str]]:
    """Separate the leading '--'-delimited group from the rest of argv

    Only the first '--' acts as a separator; it is dropped, and any later
    '--' tokens are left in the remainder untouched.  If there is no '--'
    at all, the whole of argv (the same list object) is the group and the
    remainder is empty.

    Args:
        argv: List of strings to split

    Returns:
        (group, remainder): Tuple of (first group before '--',
                                     remaining args after '--')

    Examples:
        split_group(['a', 'b']) -> (['a', 'b'], [])
        split_group(['a', '--', 'b', 'c']) -> (['a'], ['b', 'c'])
        split_group(['a', '--', 'b', '--', 'c']) -> (['a'], ['b', '--', 'c'])
    """
    try:
        sep = argv.index("--")
    except ValueError:
        # No separator: everything belongs to the first group
        return argv, []
    return argv[:sep], argv[sep + 1:]


def check_exeter_cmd(args: List[str]) -> Tuple[List[str], List[str]]:
    """Split and check arguments showing how to run an exeter test program

    The first '--'-separated group of *args* is taken as the command that
    runs the test program.  It is invoked with '--exeter' appended and
    must succeed and print exactly the expected protocol version line.

    Args:
        args: Arguments to split and validate

    Returns:
        Tuple of (exeter_cmd, extra_args) if valid.  On any failure an
        error is printed to stderr and the process exits with HARD_ERROR.
    """
    exeter_cmd, extra_args = split_group(args)
    if not exeter_cmd:
        print("Error: No test program specified", file=sys.stderr)
        sys.exit(HARD_ERROR)

    # Verify the command is an exeter test program by checking --exeter option
    try:
        result = subprocess.run(exeter_cmd + ["--exeter"],
                                capture_output=True, check=True,
                                text=True, encoding='utf-8')
    except (subprocess.CalledProcessError, UnicodeDecodeError):
        # A command that fails, or whose output isn't even valid UTF-8,
        # is not speaking the exeter protocol
        print(f"Error: '{' '.join(exeter_cmd)}' is not an exeter test program",
              file=sys.stderr)
        sys.exit(HARD_ERROR)
    except FileNotFoundError:
        print(f"Error: Command '{exeter_cmd[0]}' not found", file=sys.stderr)
        sys.exit(HARD_ERROR)
    except OSError as e:
        # Any other failure to start the program (e.g. PermissionError for
        # a file that isn't executable); FileNotFoundError is handled above
        print(f"Error: Could not run '{exeter_cmd[0]}': {e}", file=sys.stderr)
        sys.exit(HARD_ERROR)

    expected_output = f"{EXETER_PROTOCOL_VERSION}\n"
    if result.stdout != expected_output:
        print("Error: Unexpected exeter protocol version: " +
              f"{result.stdout.strip()}",
              file=sys.stderr)
        print(f"Expected: {EXETER_PROTOCOL_VERSION}", file=sys.stderr)
        sys.exit(HARD_ERROR)

    return exeter_cmd, extra_args


def validate_identifier(identifier: str) -> bool:
    """Validate identifier format (test IDs and metadata keys)

    Must use only alphanumeric, dots, semicolons, underscores
    Must be non-empty

    fullmatch() is used rather than match(): with match(), the pattern's
    '$' anchor also matches just before a trailing newline, so an
    otherwise valid identifier followed by a newline would be accepted.

    Args:
        identifier: Identifier to validate

    Returns:
        True if valid, False otherwise
    """
    return TEST_ID_PATTERN.fullmatch(identifier) is not None


def test_id_list(exeter_cmd: List[str],
                 test_args: Optional[List[str]] = None) -> List[str]:
    """Get list of test IDs from an exeter test program

    Runs the program with '--list' followed by *test_args*.  Each output
    line, stripped of surrounding whitespace, must be a valid test
    identifier.

    Args:
        exeter_cmd: Command to run the test program
        test_args: Additional arguments for the test program; None (the
            default) means none.  None is used instead of a mutable
            default list shared between calls.

    Returns:
        List of test IDs, in the order the program printed them

    Raises:
        subprocess.CalledProcessError: If the test program fails
        UnicodeDecodeError: If the program's output is not valid UTF-8
        SystemExit: If invalid test IDs are found
    """
    cmd = exeter_cmd + ["--list"] + list(test_args or [])
    result = subprocess.run(cmd, capture_output=True, check=True,
                            text=True, encoding='utf-8')

    validated_ids = []
    for line in result.stdout.splitlines():
        tid = line.strip()
        if not validate_identifier(tid):
            print(f"Error: Invalid test identifier: {tid}", file=sys.stderr)
            sys.exit(HARD_ERROR)
        validated_ids.append(tid)

    return validated_ids


def parse_metadata_value(value: str) -> str:
    """Decode C-style escape sequences in metadata value

    Escape sequences are decoded with Python's unicode_escape codec.
    That codec takes bytes and treats every byte that is not part of an
    escape sequence as Latin-1.  Feeding it UTF-8 bytes would therefore
    corrupt every non-ASCII character (e.g. 'é' would come back as two
    mojibake characters).  Instead, the value is encoded as Latin-1:
    characters Latin-1 can represent map to themselves, and any others
    are first rewritten as Python unicode escapes by the
    'backslashreplace' error handler, which the codec then decodes back
    to the original character.

    Args:
        value: String with potential escape sequences

    Returns:
        Decoded string

    Raises:
        ValueError: If value has invalid encoding (e.g. a truncated
            escape sequence)
    """
    try:
        raw = value.encode('latin-1', 'backslashreplace')
        return raw.decode('unicode_escape')
    except (UnicodeDecodeError, ValueError) as e:
        raise ValueError(f"Invalid encoding in metadata value: {value!r}") \
            from e


def parse_metadata_line(line: str) -> tuple[str, str]:
    """Split one 'key=value' metadata line and decode its value

    The key must be a valid identifier.  The value may contain further
    '=' characters and C-style escape sequences.  Empty lines are not
    allowed (they have no '=').

    Args:
        line: Line to parse

    Returns:
        Tuple of (key, decoded_value)

    Raises:
        ValueError: If line format is invalid or value has invalid encoding
    """
    # partition() splits on the first '=' only; sep is '' if none exists
    key, sep, raw_value = line.partition('=')
    if not sep:
        raise ValueError(f"Missing '=' in metadata: {line!r}")
    if not key:
        raise ValueError(f"Empty metadata key: {line!r}")
    if not validate_identifier(key):
        raise ValueError(f"Invalid metadata key {key!r}")
    return key, parse_metadata_value(raw_value)


def parse_metadata(output: str) -> dict[str, str]:
    """Turn a test program's complete --metadata output into a dict

    Every line is parsed with parse_metadata_line(), in order, so the
    first problem encountered is the one reported.

    Args:
        output: Complete --metadata output

    Returns:
        Dictionary mapping keys to decoded values

    Raises:
        ValueError: If any line has invalid format or encoding, or a key
            appears more than once
    """
    # Lazily parsed, so errors surface in line order just as in a plain loop
    entries = (parse_metadata_line(raw) for raw in output.splitlines())
    result: dict[str, str] = {}
    for name, decoded in entries:
        if name in result:
            raise ValueError(f"Duplicate metadata key {name!r}")
        result[name] = decoded
    return result


def test_metadata(exeter_cmd: List[str], test_id: str) -> dict[str, str]:
    """Fetch and parse the metadata of a single test

    Runs the test program with '--metadata TEST_ID' and parses its
    standard output with parse_metadata().

    Args:
        exeter_cmd: Command to run the test program
        test_id: Test identifier to get metadata for

    Returns:
        Dictionary mapping metadata keys to decoded values

    Raises:
        subprocess.CalledProcessError: If the test program fails
        ValueError: If metadata format is invalid or has duplicate keys
            (including UnicodeDecodeError for non-UTF-8 output)
    """
    argv = exeter_cmd + ["--metadata", test_id]
    proc = subprocess.run(argv, check=True, capture_output=True,
                          encoding='utf-8', text=True)
    output = proc.stdout
    return parse_metadata(output)


def _build_parser() -> argparse.ArgumentParser:
    """Build the parser for exetool's own options and its subcommands"""
    parser = argparse.ArgumentParser(
        description="Tool for working with exeter test programs",
        prog="exetool"
    )
    parser.add_argument("--version", action="version",
                        version=f"%(prog)s {EXETOOL_VERSION}")

    commands = parser.add_subparsers(dest="command",
                                     help="Available commands",
                                     metavar="COMMAND")

    commands.add_parser(
        "probe",
        help="Check if the given command is an exeter test program")
    commands.add_parser(
        "list",
        help="List available tests in exeter test programs")

    metadata_cmd = commands.add_parser(
        "metadata",
        help="Show metadata for specified tests or all tests")
    metadata_cmd.add_argument("-f", "--format",
                              choices=["json"],
                              default="json",
                              help="Output format (default: json)")

    commands.add_parser(
        "description", aliases=["desc"],
        help="Extract the description from a test's metadata")
    commands.add_parser(
        "avocado",
        help="Generate JSON manifest for Avocado runner")
    commands.add_parser(
        "bats",
        help="Generate BATS script to run each test")

    return parser


def main() -> int:
    """Main entry point for exetool

    Everything before the first '--' is parsed as exetool's own options
    and subcommand; everything after it is handed to the subcommand.

    Returns:
        Process exit status
    """
    own_args, rest = split_group(sys.argv[1:])

    parser = _build_parser()
    args = parser.parse_args(own_args)

    if not args.command:
        parser.print_help()
        return HARD_ERROR

    # argparse records whichever name was typed, so the alias needs an entry
    handlers = {
        "probe": cmd_probe,
        "list": cmd_list,
        "metadata": cmd_metadata,
        "description": cmd_desc,
        "desc": cmd_desc,
        "avocado": cmd_avocado,
        "bats": cmd_bats,
    }
    handler = handlers.get(args.command)
    if handler is None:
        print(f"Unknown command: {args.command}", file=sys.stderr)
        return HARD_ERROR
    return handler(args, rest)


def cmd_probe(args: argparse.Namespace, remaining_args: List[str]) -> int:
    """Check if the given command is an exeter test program

    check_exeter_cmd() exits the process with an error when the command
    is not an exeter test program, so getting past it means success.
    Success produces no output, for use from scripts.
    """
    check_exeter_cmd(remaining_args)
    return 0


def cmd_list(args: argparse.Namespace, remaining_args: List[str]) -> int:
    """Print the validated test IDs of an exeter test program, one per line

    Returns:
        0 on success, HARD_ERROR if the test program fails
    """
    exeter_cmd, test_args = check_exeter_cmd(remaining_args)

    try:
        ids = test_id_list(exeter_cmd, test_args)
    except subprocess.CalledProcessError as err:
        print(f"Error running test program: {err}", file=sys.stderr)
        return HARD_ERROR

    sys.stdout.writelines(f"{tid}\n" for tid in ids)
    return 0


def cmd_metadata(args: argparse.Namespace, remaining_args: List[str]) -> int:
    """Show metadata for specified tests or all tests

    Test IDs may follow the test program command after a second '--'; if
    none are given, every test reported by --list is included.  Output is
    a JSON object mapping each test ID to its metadata dictionary.

    Returns:
        0 on success, HARD_ERROR on any error
    """

    exeter_cmd, test_args = check_exeter_cmd(remaining_args)

    try:
        # If no test IDs specified, get all tests
        if not test_args:
            test_ids = test_id_list(exeter_cmd)
        else:
            test_ids = test_args
            # Validate all test ID formats
            for testid in test_ids:
                if not validate_identifier(testid):
                    print(f"Error: Invalid test ID format: {testid}",
                          file=sys.stderr)
                    return HARD_ERROR

        # Get metadata for all specified tests
        metadata = {test_id: test_metadata(exeter_cmd, test_id)
                    for test_id in test_ids}

        # Output in specified format
        if args.format == "json":
            json.dump(metadata, sys.stdout, indent=2)
            print()  # Add newline for better output formatting
        return 0

    except ValueError as e:
        print(f"Error: {e}", file=sys.stderr)
        return HARD_ERROR
    except subprocess.CalledProcessError as e:
        # Prefer the program's own diagnostics, but don't print a bare
        # "Error: " when it wrote nothing to stderr
        detail = (e.stderr or "").strip()
        print(f"Error: {detail or e}", file=sys.stderr)
        return HARD_ERROR


def cmd_desc(args: argparse.Namespace,
             remaining_args: List[str]) -> int:
    """Extract the description from a test's metadata

    Exactly one test ID must follow the test program command, after a
    second '--'.  Prints the test's 'description' metadata value, or the
    test ID itself if the test has no description.

    Returns:
        0 on success, HARD_ERROR on any error
    """

    exeter_cmd, test_args = check_exeter_cmd(remaining_args)

    # Extract test ID from test_args
    if not test_args:
        print("Error: No test ID specified", file=sys.stderr)
        return HARD_ERROR

    if len(test_args) > 1:
        print("Error: Only one test ID can be specified", file=sys.stderr)
        return HARD_ERROR
    testid = test_args[0]

    # Validate test ID format
    if not validate_identifier(testid):
        print(f"Error: Invalid test ID format: {testid}", file=sys.stderr)
        return HARD_ERROR

    try:
        metadata = test_metadata(exeter_cmd, testid)
    except ValueError as e:
        # Malformed metadata output (bad line, invalid or duplicate key,
        # bad encoding) - report it rather than dying with a traceback
        print(f"Error: {e}", file=sys.stderr)
        return HARD_ERROR
    except subprocess.CalledProcessError as e:
        # Prefer the program's own diagnostics, but don't print a bare
        # "Error: " when it wrote nothing to stderr
        detail = (e.stderr or "").strip()
        print(f"Error: {detail or e}", file=sys.stderr)
        return HARD_ERROR

    # Extract description, fallback to testid
    print(metadata.get('description', testid))
    return 0


def cmd_avocado(args: argparse.Namespace, remaining_args: List[str]) -> int:
    """Generate JSON manifest for Avocado runner from --list output

    Emits one 'exec-test' entry per test ID.  Each entry runs the test
    program with the test ID as its final argument and tells Avocado to
    treat exit status SKIP_CODE as a skip.

    Returns:
        0 on success, HARD_ERROR if the test program fails
    """
    exeter_cmd, test_args = check_exeter_cmd(remaining_args)

    try:
        ids = test_id_list(exeter_cmd, test_args)
    except subprocess.CalledProcessError as err:
        print(f"Error running test program: {err}", file=sys.stderr)
        return HARD_ERROR

    # check_exeter_cmd() guarantees a non-empty command
    uri, *base_args = exeter_cmd
    prefix = " ".join(exeter_cmd)
    manifest = [
        {
            "kind": "exec-test",
            "uri": uri,
            "identifier": f"{prefix}:{tid}",
            "args": base_args + [tid],
            "config": {"runner.exectest.exitcodes.skip": [SKIP_CODE]},
        }
        for tid in ids
    ]

    json.dump(manifest, sys.stdout)
    print()  # Terminate the JSON with a newline
    return 0


def _bats_test_name(exeter_cmd: List[str], test_id: str) -> str:
    """Build the sanitised, shell-quoted BATS test name for one test

    Uses the test's 'description' metadata.  Falls back to the test ID if
    the metadata can't be retrieved or parsed, or if the description is
    blank.
    """
    try:
        metadata = test_metadata(exeter_cmd, test_id)
        description = metadata.get('description', test_id)
    except (ValueError, subprocess.CalledProcessError):
        # Fall back to test_id if metadata fails
        description = test_id

    if not description.strip():
        # A blank description would give the test an empty or
        # whitespace-only name that can't be identified in BATS output
        description = test_id

    # Shell-quoting BATS names is not enough - it appears that
    # BATS sometimes removes at least one quoting layer and
    # can still be thrown by some shell metacharacters.  I
    # can't find definitive documentation on what characters
    # are safe, so for now, replace [], {}, $, `, ', ", \, \n,
    # \r and \t with _
    return shlex.quote(BATS_UNSAFE_PATTERN.sub('_', description))


def cmd_bats(args: argparse.Namespace, remaining_args: List[str]) -> int:
    """Generate BATS script from exeter test program

    Writes a BATS script to stdout with one @test per test ID from the
    program's --list output.  Each test runs the program with its test ID,
    echoes the output, skips on exit status SKIP_CODE and exits with the
    program's status on any other non-zero status.

    Returns:
        0 on success, HARD_ERROR if listing the tests fails
    """

    exeter_cmd, test_args = check_exeter_cmd(remaining_args)

    try:
        # Run the program with --list to get test IDs
        test_ids = test_id_list(exeter_cmd, test_args)
    except subprocess.CalledProcessError as e:
        print(f"Error running test program: {e}", file=sys.stderr)
        return HARD_ERROR

    # Generate BATS script
    print("#! /usr/bin/env bats")
    print()
    print('bats_require_minimum_version 1.5.0')

    program_path = shlex.join(exeter_cmd)
    for test_id in test_ids:
        bats_name = _bats_test_name(exeter_cmd, test_id)

        print(f'@test {bats_name} {{')
        print(f'    run -- {program_path} {shlex.quote(test_id)}')
        print('    echo "$output"')
        print(f'    if [ "$status" = {SKIP_CODE} ]; then')
        print('        skip')
        print('    elif [ "$status" != 0 ]; then')
        print('        exit "$status"')
        print('    fi')
        print('}')
        print()

    return 0


if __name__ == "__main__":
    # main()'s return value becomes the process exit status
    raise SystemExit(main())
