#!/usr/bin/env python3
import shutil
import sys
from argparse import ArgumentParser
from pathlib import Path
from subprocess import STDOUT, CalledProcessError, check_output
from typing import Iterator, Optional

# Directories, relative to the repository root, that are searched recursively
# for C/C++ sources and headers to format.
DIRECTORIES = ["src/iterative_ensemble_smoother"]


def find_clang_format_binary(clang_format: Optional[str]) -> Path:
    """
    Locate the clang-format executable that should be run.

    If ``clang_format`` is None, "clang-format" is looked up in the PATH
    environment variable. Otherwise ``clang_format`` may be either a path to
    an existing file or a bare command name that is looked up in PATH.

    Detect if clang-format was installed from PyPI and use the actual binary
    rather than the incredibly slow Python wrapper: when the located file
    starts with a shebang, the interpreter named there is asked where the
    ``clang_format`` module lives, and the binary bundled next to it is
    returned. If that lookup does not yield an existing file, the located
    script itself is returned.

    Exits the process (via ``sys.exit``) when no executable can be found.
    """
    if clang_format is None:
        located = shutil.which("clang-format")
        if located is None:
            sys.exit("No viable executable 'clang-format' found in PATH")
    elif Path(clang_format).is_file():
        located = clang_format
    else:
        # Not an existing file, so treat it as a command name (for example
        # "clang-format-15") and resolve it through PATH.
        located = shutil.which(clang_format)
        if located is None:
            sys.exit(f"No viable executable '{clang_format}' found in PATH")

    with open(located, "rb") as f:
        head = f.read(512)

    if head[:2] != b"#!":
        # File does not contain shebang, assuming real clang-format
        print(f"Using clang-format: {located}")
        return Path(located)

    # Everything between '#!' and the first newline, split into argv so that
    # shebangs such as '#!/usr/bin/env python3' can be executed.
    interpreter = head[2:].split(b"\n", 1)[0].decode(errors="replace").split()

    # Locate the Python 'clang-format' module path
    mod_path = ""
    if interpreter:
        try:
            mod_path = (
                check_output(
                    [
                        *interpreter,
                        "-c",
                        "import clang_format;print(clang_format.__file__)",
                    ]
                )
                .decode()
                .strip()
            )
        except (CalledProcessError, OSError):
            # The script is not the PyPI wrapper (or its interpreter cannot
            # be run); fall through and use the script itself.
            mod_path = ""

    if mod_path:
        # We assume that the location of the actual binary is always in the
        # same location relative to the module.
        bundled = Path(mod_path).parent / "data" / "bin" / "clang-format"
        if bundled.is_file():
            print(f"Using clang-format: {bundled}")
            return bundled

    print(f"Using clang-format: {located}")
    return Path(located)


def source_root() -> Path:
    """
    Return the repository root containing this script.

    Walks upwards from the directory of this file and returns the first
    directory that has a ``.git`` entry. The entry may be a directory (a
    normal clone) or a file (git worktrees and submodules). Exits the process
    (via ``sys.exit``) if the filesystem root is reached without finding one.
    """
    start = Path(__file__).parent.resolve()
    # 'parents' ends at the filesystem root on every platform, so the walk
    # always terminates without comparing against a hard-coded "/".
    for node in (start, *start.parents):
        if (node / ".git").exists():
            return node
    sys.exit("Could not find the source root (no .git directory)")


def enumerate_sources() -> Iterator[Path]:
    """
    Yield the path of every file with a ``c``, ``h``, ``cpp`` or ``hpp``
    extension found recursively below each entry of ``DIRECTORIES``, with
    entries taken relative to the source root.
    """
    repo = source_root()
    suffixes = ("c", "h", "cpp", "hpp")
    for directory in DIRECTORIES:
        for suffix in suffixes:
            yield from repo.glob(f"{directory}/**/*.{suffix}")


def reformat(clang_format: Path, dry_run: bool, verbose: bool) -> None:
    """
    Check, and unless ``dry_run`` is set, fix the formatting of all sources.

    Every file from ``enumerate_sources()`` is first checked with
    ``clang-format --dry-run -Werror``. Files that fail the check are either
    only reported (``dry_run``) or rewritten in place with ``clang-format -i``.
    A summary is printed at the end.

    Args:
        clang_format: The clang-format executable to invoke.
        dry_run: Only report which files would change; modify nothing.
        verbose: Print each file as it is checked.

    Exits the process with status 1 when, in dry-run mode, at least one file
    would be reformatted, or otherwise when at least one file could not be
    reformatted.
    """
    total = 0
    need_reformat = 0
    failed_reformat = 0

    root = source_root()
    for path in enumerate_sources():
        relpath = path.relative_to(root)
        total += 1
        if verbose:
            print("checking ", path)

        try:
            check_output([clang_format, "--dry-run", "-Werror", path], stderr=STDOUT)
        except CalledProcessError:
            need_reformat += 1
        else:
            continue  # This file passed the check, continue

        if dry_run:
            print("would reformat", relpath)
            continue

        try:
            check_output([clang_format, "-i", "-Werror", path], stderr=STDOUT)
        except CalledProcessError as err:
            failed_reformat += 1
            print("failed to reformat", relpath)
            # Forward what clang-format reported so the failure can be
            # diagnosed; sent to stderr to keep the stdout report unchanged.
            diagnostics = (err.output or b"").decode(errors="replace").strip()
            if diagnostics:
                print(diagnostics, file=sys.stderr)
        else:
            print("reformatted", relpath)

    if dry_run:
        print(
            f"{need_reformat} files would be reformatted, "
            f"{total - need_reformat} files would be left unchanged."
        )
        if need_reformat > 0:
            sys.exit(1)
    else:
        successfully_reformatted = need_reformat - failed_reformat
        print(
            f"{successfully_reformatted} files reformatted, "
            f"{total - successfully_reformatted} files left unchanged, "
            f"{failed_reformat} files failed to reformat."
        )
        if failed_reformat > 0:
            sys.exit(1)


def main() -> None:
    """
    Command-line entry point: parse the options, locate the clang-format
    binary and run the formatter over the sources.
    """
    parser = ArgumentParser()

    # Report-only mode; passed on to reformat() as its dry-run flag.
    parser.add_argument(
        "-c",
        "--check",
        help="Performs a check without modifying any files",
        action="store_true",
        default=False,
    )
    # Optional override for which clang-format to use.
    parser.add_argument(
        "--clang-format",
        help="Name/path of the clang-format binary",
        type=str,
    )
    parser.add_argument(
        "-v",
        "--verbose",
        help="Output verbosely",
        action="store_true",
        default=False,
    )

    options = parser.parse_args()

    binary = find_clang_format_binary(options.clang_format)
    reformat(binary, options.check, options.verbose)


# Run the formatter only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
