#!/usr/bin/env python3

import subprocess
import sys


def get_current_branch():
    """Return the short name of the branch that HEAD points to.

    Runs ``git symbolic-ref --short HEAD``. If git reports a failure, its
    stderr is echoed to this process's stderr and the script exits with
    git's return code.
    """
    proc = subprocess.run(
        ['git', 'symbolic-ref', '--short', 'HEAD'],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    if proc.returncode == 0:
        return proc.stdout.decode('utf-8').strip()

    # git could not resolve HEAD to a branch: surface its message and abort.
    print(proc.stderr.decode('utf-8'), file=sys.stderr)
    sys.exit(proc.returncode)

def get_changed_files():
    """Return the paths staged for the pending commit.

    Runs ``git diff --cached --name-only -z`` and splits its output on NUL.

    Returns:
        A list of path strings as reported by git. The list is empty when
        nothing is staged or git produced no output.
    """
    # -z makes git emit NUL-terminated paths verbatim. Without it, paths that
    # contain non-ASCII or special characters are wrapped in double quotes
    # with escapes, which would defeat suffix checks such as ``.endswith('.py')``.
    result = subprocess.run(
        ['git', 'diff', '--cached', '--name-only', '-z'],
        stdout=subprocess.PIPE,
    )
    # surrogateescape keeps paths that are not valid UTF-8 from raising.
    raw_output = result.stdout.decode('utf-8', errors='surrogateescape')
    # Drop empty fields so that "nothing staged" yields [] rather than [''].
    return [path for path in raw_output.split('\0') if path]

def run_command(command, print_stdout: bool = False):
    """Run *command* through the shell and return its captured stdout.

    Args:
        command: Shell command line. It is executed with ``shell=True``, so
            only trusted, hard-coded strings should be passed.
        print_stdout: When true, the command's stdout is also echoed if the
            command fails (useful for tools that report findings on stdout).

    Returns:
        The command's stdout decoded as UTF-8, with undecodable bytes
        replaced by U+FFFD.

    On a non-zero exit status the command, its stderr and (optionally) its
    stdout are printed to this process's stderr, and the script exits with
    the command's return code.
    """
    result = subprocess.run(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    # errors='replace': tool output is diagnostic text, so a stray non-UTF-8
    # byte must not crash the hook with UnicodeDecodeError.
    stdout_text = result.stdout.decode('utf-8', errors='replace')
    if result.returncode != 0:
        # Diagnostics go to stderr, matching get_current_branch's error path.
        print(f"non zero exit ({result.returncode}) for cmd: {command}", file=sys.stderr)
        print(result.stderr.decode('utf-8', errors='replace'), file=sys.stderr)
        if print_stdout:
            print(stdout_text, file=sys.stderr)
        sys.exit(result.returncode)
    return stdout_text


def run_black_formatting():
    """Format the tree with Black and stage the result.

    Installs Black with ``uv pip install``, runs it on the current
    directory, then runs ``git add .``. Note that ``git add .`` stages every
    change under the current directory, not only files Black rewrote.
    """
    print("Running black for code formatting...")
    for shell_line in ('uv pip install black && uv run black .', 'git add .'):
        # run_command exits the script on failure, so later steps are skipped.
        run_command(shell_line)

def run_isort():
    """Sort and deduplicate imports with isort, then stage the result.

    Installs isort with ``uv pip install``, runs it on the current
    directory, then runs ``git add .`` (which stages every change under the
    current directory).
    """
    print("Running isort for import sorting and deduplication...")
    isort_cmd = 'uv pip install isort && uv run isort .'
    stage_cmd = 'git add .'
    run_command(isort_cmd)
    run_command(stage_cmd)

def autoflake_files():
    """Strip unused imports and variables with autoflake, then stage the result.

    Installs autoflake with ``uv pip install`` and applies it in place to
    every ``*.py`` file found under the current directory, skipping
    ``./.venv``. Finishes with ``git add .`` (which stages every change
    under the current directory).
    """
    print("Running autoflake...")
    # Adjacent literals are concatenated into one shell command line.
    autoflake_cmd = (
        'uv pip install autoflake && '
        'uv run find . -type f -name "*.py" ! -path "./.venv/*" '
        '-exec autoflake --ignore-init-module-imports --in-place '
        '--remove-unused-variables --remove-all-unused-imports {} +'
    )
    run_command(autoflake_cmd)
    run_command('git add .')

def run_vulture():
    """Run vulture over the current directory.

    Excludes ``.venv``, ``alembic/`` and ``examples/`` and ignores the listed
    decorators. The command's stdout is echoed on failure, and run_command
    exits the script when vulture returns a non-zero status.
    """
    print("Running vulture...")
    vulture_cmd = (
        'uv run vulture --exclude ".venv,alembic/,examples/" '
        '--ignore-decorators="@pytest.fixture,@allow_unused,@mcp_app.command,@app.command" .'
    )
    run_command(vulture_cmd, print_stdout=True)

def ensure_not_main():
    """Abort when committing on ``main`` unless HEAD is a version bump.

    If the current branch is not ``main`` this returns immediately.
    Otherwise the message of the commit at HEAD is read, and the script
    exits with status 1 unless that message starts with "Bump version".
    """
    protected_branch = "main"
    if get_current_branch() != protected_branch:
        return

    # NOTE(review): this reads the message of the existing HEAD commit, i.e.
    # the most recent commit, not the message of the commit being created.
    # Confirm that this is the intended check.
    head_log = subprocess.run(
        ['git', 'log', '-1', '--pretty=%B', 'HEAD'],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    head_message = head_log.stdout.decode('utf-8').strip()

    # Version-bump commits are the one exception allowed on the protected branch.
    if head_message.startswith("Bump version"):
        return

    print(f"Commits directly to {protected_branch} not allowed")
    sys.exit(1)

def main():
    """Entry point of the hook.

    Enforces the protected-branch rule. If any staged path ends in ``.py``,
    it then runs black, autoflake, isort and vulture in that order. Each
    step exits the script itself on failure; otherwise the script exits
    with status 0 so the commit proceeds.
    """
    ensure_not_main()

    staged_paths = get_changed_files()
    touches_python = any(path.endswith('.py') for path in staged_paths)

    if touches_python:
        for step in (run_black_formatting, autoflake_files, run_isort, run_vulture):
            step()

    # All checks passed (or none were needed): let the commit continue.
    sys.exit(0)

# Run the hook only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()

