#!/usr/bin/env python3
"""
A multi-command tool to automate steps of the release process
"""
import os
import re
import shutil
import subprocess
import sys


try:
    import tomllib
except ImportError:
    sys.exit("This script requires Python 3.11 (it needs tomllib)")
from pathlib import Path
from typing import Any, List, NoReturn, Optional, Union

import click
from hmsl.hashicorp_vault import (
    start_hashicorp_vault_server,
    stop_hashicorp_vault_server,
)


# Passed to the click group so that `-h` is accepted as well as `--help`.
CONTEXT_SETTINGS = dict(help_option_names=["-h", "--help"])


# Locations of the files and directories this script reads or rewrites, all
# resolved relative to the parent of the directory containing this script.
ROOT_DIR = Path(__file__).parent.parent
CHANGELOG_PATH = ROOT_DIR / "CHANGELOG.md"
INIT_PATH = ROOT_DIR / "ggshield" / "__init__.py"
ACTION_PATHS = [ROOT_DIR / "actions" / "secret" / "action.yml"]
CASSETTES_DIR = ROOT_DIR / "tests" / "unit" / "cassettes"

# Regular expressions matching the names of the branches this script must be
# run from, except in dev mode.
RELEASE_BRANCHES = {"^main$", r"^\d+\.\d+\.x$"}

# Human-readable list of the patterns, used in help and error messages.
# The patterns are sorted because the iteration order of a set of strings is
# not stable from one run to the next (string hashing is randomized), which
# would otherwise make the messages vary between invocations.
RELEASE_BRANCHES_STR = ", ".join(f'"{x}"' for x in sorted(RELEASE_BRANCHES))


def get_version() -> str:
    """Return the `__version__` attribute of the `ggshield` package.

    The package is imported when the function is called, not when this script
    is loaded.
    """
    import ggshield

    return ggshield.__version__


def get_tag(version: str) -> str:
    """Return the git tag name for `version`: the version prefixed with "v"."""
    return "v{}".format(version)


def check_run(
    cmd: List[Union[str, Path]], **kwargs: Any
) -> subprocess.CompletedProcess:
    """Run `cmd` with `subprocess.run()` in text mode, raising on failure.

    Extra keyword arguments are forwarded to `subprocess.run()`. Because
    `check=True` is always set, a non-zero exit status raises
    `subprocess.CalledProcessError`.
    """
    completed = subprocess.run(cmd, **kwargs, check=True, text=True)
    return completed


def log_progress(message: str) -> None:
    """Print `message` in magenta on stderr."""
    styled = click.style(message, fg="magenta")
    click.echo(styled, err=True)


def log_error(message: str) -> None:
    """Print `message` on stderr, preceded by a red "ERROR:" label."""
    label = click.style("ERROR:", fg="red")
    line = "{} {}".format(label, message)
    click.secho(line, err=True)


def fail(message: str) -> NoReturn:
    """Log `message` as an error, then terminate the script with exit status 1.

    This function never returns: `sys.exit()` raises `SystemExit`. It is
    annotated `NoReturn` so that type checkers understand that code following
    a call to `fail()` is only reached when `fail()` was not called.
    """
    log_error(message)
    sys.exit(1)


def get_current_branch_name() -> str:
    """Return the short name of the branch HEAD points to.

    Aborts the script through `fail()` if the git command exits with an error.
    """
    command = ["git", "symbolic-ref", "--short", "HEAD"]
    try:
        result = check_run(command, capture_output=True)
    except subprocess.CalledProcessError:
        fail("Could not get the current branch")
    branch_name = result.stdout.strip()
    return branch_name


def check_working_tree_is_on_release_branch() -> bool:
    """Return True if the current branch name matches one of `RELEASE_BRANCHES`.

    Logs an error and returns False otherwise.
    """
    current = get_current_branch_name()
    if any(re.search(regex, current) for regex in RELEASE_BRANCHES):
        return True

    log_error(
        f"This script must be run on a branch matching one of these patterns: {RELEASE_BRANCHES_STR}"
    )
    return False


def check_working_tree_is_clean() -> bool:
    """Return True if `git status --porcelain` prints nothing.

    Logs an error and returns False if it reports at least one line.
    """
    status = check_run(["git", "status", "--porcelain"], capture_output=True)
    changed_entries = status.stdout.splitlines()
    if not changed_entries:
        return True

    log_error("Working tree contains changes")
    return False


def check_working_tree_is_up_to_date() -> bool:
    """Return True if the current branch has an `origin/` counterpart and that
    counterpart holds no commit missing from the local branch.

    Logs an error and returns False otherwise.
    """
    local = get_current_branch_name()
    remote = f"origin/{local}"

    # The remote counterpart must be listed by `git branch -r`
    listing = check_run(["git", "branch", "-r"], capture_output=True)
    remote_branches = {line.strip() for line in listing.stdout.splitlines()}
    if remote not in remote_branches:
        log_error(f"No remote branch for {local}")
        return False

    # After a fetch, any output from `git log local..remote` means the remote
    # branch has commits the local branch does not have
    check_run(["git", "fetch"], capture_output=True)
    incoming = check_run(
        ["git", "log", f"{local}..{remote}"], capture_output=True
    ).stdout
    if incoming:
        log_error("Working tree is not up-to-date")
        return False

    return True


def check_dependencies() -> bool:
    """
    Check that ggshield only depends on released versions of its dependencies.

    Reads `uv.lock`, looks up the packages listed in the `dependencies` of the
    `ggshield` entry, and returns False (after logging an error) as soon as one
    of them has a "git" key in its `source`. Returns True otherwise.
    """
    lock_path = ROOT_DIR / "uv.lock"
    with lock_path.open("rb") as lock_file:
        lock_data = tomllib.load(lock_file)

    # Index the locked packages by name
    packages_by_name = {}
    for package in lock_data["package"]:
        packages_by_name[package["name"]] = package

    # Names of the packages ggshield itself depends on
    dependency_names = {
        dep["name"] for dep in packages_by_name["ggshield"]["dependencies"]
    }

    for name in dependency_names:
        if "git" in packages_by_name[name]["source"]:
            log_error(
                f"uv.lock contains a dependency on an unreleased version of {name}"
            )
            return False

    return True


@click.group(context_settings=CONTEXT_SETTINGS)
@click.option(
    "--dev-mode",
    is_flag=True,
    help=(
        "Do not abort if the working tree contains changes or if we are not on a branch matching one of"
        f" {RELEASE_BRANCHES_STR}."
    ),
)
def main(dev_mode: bool) -> int:
    """Helper script to release ggshield. Commands should be run in this order:

    \b
    1. run-tests
    2. prepare
    3. tag
    4. publish-gh-release
    """
    # Run every check, without stopping at the first failure, so that each
    # failing check gets to log its own error.
    outcomes = [
        check()
        for check in (
            check_working_tree_is_on_release_branch,
            check_working_tree_is_clean,
            check_working_tree_is_up_to_date,
            check_dependencies,
        )
    ]
    if all(outcomes):
        return 0

    # At least one check failed: abort unless dev mode is on
    if not dev_mode:
        fail("Use --dev-mode to ignore")
    else:
        log_progress("Ignoring errors because --dev-mode is set")

    return 0


@main.command()
def run_tests() -> None:
    """Run all tests.

    Unit-tests are run without cassettes. This ensures the recorded cassettes
    still match production reality.
    """

    # Check we have an API key
    if "GITGUARDIAN_API_KEY" not in os.environ:
        fail("Environment variable $GITGUARDIAN_API_KEY is not set")

    try:
        # If CASSETTES_DIR does not exist, tests fail, so recreate it
        log_progress("Removing cassettes")
        shutil.rmtree(CASSETTES_DIR)
        CASSETTES_DIR.mkdir()

        # Start the hashicorp Vault server for HMSL tests
        log_progress("Starting Hashicorp server for HMSL tests")
        start_hashicorp_vault_server()

        log_progress("Running unit tests")
        check_run(["pytest", "tests/unit"], cwd=ROOT_DIR)

        log_progress("Running functional tests")
        check_run(["pytest", "tests/functional"], cwd=ROOT_DIR)
    finally:
        # The two cleanup steps are independent: use a nested try/finally so
        # that the Vault server is still stopped when `git restore` fails
        # (check_run() raises CalledProcessError on a non-zero exit status).
        try:
            log_progress("Restoring cassettes")
            check_run(["git", "restore", CASSETTES_DIR], cwd=ROOT_DIR)
        finally:
            log_progress("Stopping Hashicorp server")
            stop_hashicorp_vault_server()


def replace_once_in_file(path: Path, src: str, dst: str, flags: int = 0) -> None:
    """Substitute the regular expression `src` with `dst` in the file at `path`.

    `flags` are the `re` flags used to compile `src`. The file is only
    rewritten when exactly one substitution was made; any other number of
    matches aborts the script through `fail()`."""
    original = path.read_text()
    pattern = re.compile(src, flags)
    updated, match_count = pattern.subn(dst, original)
    if match_count != 1:
        fail(
            f"Did not make any change to {path}: expected 1 match for '{src}', got {match_count}."
        )
    path.write_text(updated)


def check_version(version: str) -> None:
    """Abort through `fail()` unless `version` looks like `X.Y.Z` (digits only)
    and no git tag exists for it yet."""
    # The version must be made of three dot-separated numbers
    if re.fullmatch(r"\d+\.\d+\.\d+", version) is None:
        fail(f"'{version}' is not a valid version number")

    # The tag for this version must not be listed by `git tag`
    tag = get_tag(version)
    existing_tags = check_run(["git", "tag"], capture_output=True).stdout.splitlines()
    if tag in existing_tags:
        fail(f"The {tag} tag already exists.")


def update_version(version: str) -> None:
    """Write `version` in the package `__init__.py` and in the action files."""
    # Rewrite the `__version__ = ...` line of the package
    version_line = f'__version__ = "{version}"'
    replace_once_in_file(
        INIT_PATH, "^__version__ = .*$", version_line, flags=re.MULTILINE
    )

    # Pin the version of the Docker image used by the actions
    image_pattern = r"image: 'docker://gitguardian/ggshield:(unstable|v\d+\.\d+\.\d+)'"
    image_line = f"image: 'docker://gitguardian/ggshield:v{version}'"
    for action_file in ACTION_PATHS:
        replace_once_in_file(action_file, image_pattern, image_line)


def update_changelog() -> None:
    """Run `scriv collect --edit`, then run prettier on CHANGELOG.md through
    pre-commit."""
    collect_cmd = ["scriv", "collect", "--edit"]
    check_run(collect_cmd)

    # scriv and prettier disagree on some minor formatting issue, so prettier
    # is run through pre-commit to fix CHANGELOG.md.
    # `check_run()` is deliberately not used for this step: prettier will
    # reformat the file, which makes the command exit with code 1.
    prettier_cmd = ["pre-commit", "run", "prettier", "--files", CHANGELOG_PATH]
    subprocess.run(prettier_cmd)


def commit_changes(version: str) -> None:
    """Stage the files touched by the release preparation and commit them with
    the message `chore(release): <version>`."""
    staged_paths = [CHANGELOG_PATH, "changelog.d", INIT_PATH, *ACTION_PATHS]
    check_run(["git", "add", *staged_paths])

    commit_message = f"chore(release): {version}"
    check_run(["git", "commit", "--message", commit_message])


@main.command()
@click.argument("version")
def prepare(version: str) -> None:
    """Prepare the code for the release:

    \b
    - Bump the version in __init__.py and in GitHub action files
    - Update the changelog
    - Commit changes
    """
    # Validate first: nothing is modified if the version is rejected
    check_version(version)

    update_version(version)
    update_changelog()
    commit_changes(version)

    script_name = sys.argv[0]
    log_progress(f"Done, review changes and then run `{script_name} tag`")


@main.command()
def tag() -> None:
    """Create the tag for the version, push the release branch and the tag."""
    release_version = get_version()
    tag_name = get_tag(release_version)

    # Create an annotated tag
    check_run(
        [
            "git",
            "tag",
            "--annotate",
            tag_name,
            "--message",
            f"Releasing {release_version}",
        ]
    )
    # Push the tag
    check_run(["git", "push", "origin", f"{tag_name}:{tag_name}"])
    # Push the release branch
    check_run(["git", "push"])


def get_release_notes(version: str, changelog_path: Optional[Path] = None) -> str:
    """Reads the changelog, returns the changes for version `version`, formatted
    for `gh release`.

    `changelog_path` is the file to read; it defaults to `CHANGELOG_PATH`.

    The notes are the text between the `## <version> ...` heading and the next
    line starting with `<a ` or `## `. If there is no such line (the version is
    the last section of the file), the notes run to the end of the file.

    Aborts the script through `fail()` if the changelog has no heading for
    `version`."""
    path = CHANGELOG_PATH if changelog_path is None else changelog_path

    # Extract changes from the changelog. Read it as UTF-8 explicitly so the
    # result does not depend on the platform default encoding.
    changes = path.read_text(encoding="utf-8")
    start_match = re.search(f"^## {re.escape(version)} .*", changes, flags=re.MULTILINE)
    if start_match is None:
        # Explicit check rather than `assert`, which is stripped under `-O`
        fail(f"Could not find an entry for version {version} in {path}")
    # +1 to skip the newline ending the heading line
    start_pos = start_match.end() + 1

    end_match = re.search("^(<a |## )", changes[start_pos:], flags=re.MULTILINE)
    if end_match is None:
        # No following section: take everything up to the end of the file
        end_pos = len(changes)
    else:
        # end_match offsets are relative to the slice starting at start_pos
        end_pos = start_pos + end_match.start()

    notes = changes[start_pos:end_pos]

    # Remove one level of indent
    notes = re.sub("^#", "", notes, flags=re.MULTILINE)
    return notes.strip()


@main.command()
@click.option(
    "--dry-run",
    is_flag=True,
    help="Do not publish, just print the content of the release notes",
)
def publish_gh_release(dry_run: bool = False) -> None:
    """Set the release notes of the GitHub release, then remove its "draft" status.

    GitHub CLI (https://cli.github.com/) must be installed."""

    release_version = get_version()
    release_tag = get_tag(release_version)
    release_notes = get_release_notes(release_version)

    if not dry_run:
        # Both calls edit the same release; only the trailing options differ
        edit_release = ["gh", "release", "edit", release_tag]
        check_run([*edit_release, "--title", release_version, "--notes", release_notes])
        check_run([*edit_release, "--draft=false"])
        return

    # Dry run: only show what would be published
    print(release_notes)


if __name__ == "__main__":
    # Run the click command group when this file is executed as a script, and
    # pass whatever it returns to sys.exit() as the exit status.
    sys.exit(main())
