#!/usr/bin/env python3
"""Citadel — Citadel cluster management CLI for Rancher Desktop.

Usage:
    citadel-ctl cluster status              Show cluster state
    citadel-ctl cluster configure           Configure K3s for Cilium (disable flannel/traefik)
    citadel-ctl cluster reset               Factory-reset K3s
    citadel-ctl cluster wait [--timeout N]  Wait for cluster readiness (default 120s)

    citadel-ctl image build <path> <tag>    Build image with nerdctl
    citadel-ctl image list                  List images in k8s.io namespace

    citadel-ctl node status                 Show node ages and health
    citadel-ctl node list                   List nodes with lifecycle info

    citadel-ctl lifecycle install           Install CRDs and RBAC for node lifecycle
    citadel-ctl lifecycle policies          Show active node termination policies

    citadel-ctl network install             Install Cilium CNI
    citadel-ctl network policies            Apply network policies
    citadel-ctl network test                Test network policies

    citadel-ctl monitor install             Install monitoring stack
    citadel-ctl security enable-wireguard   Enable WireGuard encryption
    citadel-ctl security enable-audit       Enable audit logging
"""

import subprocess
import sys
import time
from pathlib import Path


def get_script_dir() -> Path:
    """Return the ``scripts`` folder located next to this module file."""
    module_dir = Path(__file__).parent
    return module_dir.joinpath("scripts")


def run_script(script_name: str, extra_args: list[str] | None = None) -> int:
    """Run a bundled shell script with ``bash`` and return its exit status.

    Args:
        script_name: File name of the script inside ``get_script_dir()``.
        extra_args: Optional extra command-line arguments appended after the
            script path.

    Returns:
        The script's exit code. If the script was killed by signal N, this is
        ``128 + N``, following the shell convention. Returns 1 if the script
        file is missing or ``bash`` could not be started.
    """
    script_path = get_script_dir() / script_name

    # is_file() (not exists()) so a directory of the same name is rejected.
    if not script_path.is_file():
        print(f"Error: Script not found: {script_path}", file=sys.stderr)
        return 1

    cmd = ["bash", str(script_path), *(extra_args or [])]
    try:
        result = subprocess.run(cmd, check=False)
    except OSError as e:
        # Raised when the interpreter cannot be executed (e.g. bash not on PATH).
        print(f"Error running script: {e}", file=sys.stderr)
        return 1

    code = result.returncode
    # subprocess reports death-by-signal as -N. Passing that straight to
    # sys.exit() would yield 256-N, so translate it to the shell's 128+N.
    return 128 - code if code < 0 else code


def _get_provider():
    """Construct and return a ``RancherDesktopProvider``.

    The import happens inside the function. Commands that never call this
    helper therefore never import ``citadel_operator``.
    """
    from citadel_operator.k3s.provider import RancherDesktopProvider as provider_cls

    return provider_cls()


# ---------------------------------------------------------------------------
# Cluster commands
# ---------------------------------------------------------------------------


def cluster_status() -> int:
    """Print a one-field-per-line summary of the cluster and return 0."""
    info = _get_provider().get_status()
    rows = (
        ("State:              ", info.state.value),
        ("Kubernetes version: ", info.kubernetes_version or 'unknown'),
        ("Node name:          ", info.node_name or 'n/a'),
        ("Node ready:         ", info.node_ready),
        ("Flannel enabled:    ", info.flannel_enabled),
        ("Traefik enabled:    ", info.traefik_enabled),
        ("Container engine:   ", info.container_engine),
    )
    for label, value in rows:
        print(f"{label}{value}")
    return 0


def cluster_configure() -> int:
    """Ask the provider to configure K3s for Cilium.

    Returns:
        0 if ``configure_for_cilium()`` reports success (truthy), otherwise 1.
    """
    print("Configuring K3s for Cilium (disable flannel/traefik, fix mounts)...")
    configured = _get_provider().configure_for_cilium()
    if not configured:
        print("Error: configuration failed", file=sys.stderr)
        return 1
    print("Configuration applied. Restart Rancher Desktop if prompted.")
    return 0


def cluster_reset() -> int:
    """Factory-reset K3s through the provider.

    Returns:
        0 after printing the before/after state on success; 1 after printing
        the provider's error to stderr on failure.
    """
    outcome = _get_provider().reset_k3s()
    if not outcome.success:
        print(f"Error: reset failed: {outcome.error}", file=sys.stderr)
        return 1
    before = outcome.previous_state.value
    after = outcome.new_state.value
    print(f"K3s reset (was {before} → {after})")
    return 0


def cluster_wait(timeout: int = 120, poll_interval: float = 5.0) -> int:
    """Poll the provider until the cluster reports ready or time runs out.

    Readiness is always checked at least once, so ``timeout=0`` performs a
    single immediate check. Sleeps are capped at the time remaining, so the
    wait never overshoots the deadline by a whole poll interval. One final
    check is made once the deadline is reached.

    Args:
        timeout: Maximum number of seconds to wait.
        poll_interval: Seconds to sleep between readiness checks.

    Returns:
        0 once ``provider.is_ready()`` is truthy, 1 if the deadline passes first.
    """
    print(f"Waiting for K3s API server (timeout {timeout}s)...")
    provider = _get_provider()
    deadline = time.monotonic() + timeout
    while True:
        if provider.is_ready():
            print("Cluster is ready")
            return 0
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        # Cap the sleep so we re-check right at the deadline instead of past it.
        time.sleep(min(poll_interval, remaining))
    print("Error: cluster not ready within timeout", file=sys.stderr)
    return 1


# ---------------------------------------------------------------------------
# Image commands
# ---------------------------------------------------------------------------


def image_build(context_dir: str, tag: str) -> int:
    """Build *tag* from *context_dir* using the provider's nerdctl build.

    Returns:
        0 if ``build_image()`` reports success (truthy), otherwise 1.
    """
    print(f"Building image {tag} from {context_dir}...")
    built = _get_provider().build_image(context_dir, tag)
    if not built:
        print("Error: image build failed", file=sys.stderr)
        return 1
    print(f"Image {tag} built successfully")
    return 0


def image_list() -> int:
    """Print the provider's image list as a fixed-width table; always returns 0."""
    images = _get_provider().list_images()
    if images:
        print(f"{'REPOSITORY':<50} {'TAG':<20} {'SIZE':<10}")
        for image in images:
            print(f"{image.repository:<50} {image.tag:<20} {image.size:<10}")
    else:
        print("No images found in k8s.io namespace")
    return 0


# ---------------------------------------------------------------------------
# Node commands (unchanged)
# ---------------------------------------------------------------------------


def node_status() -> int:
    """Run ``node-lifecycle.sh status`` and return the script's exit status."""
    status_args = ["status"]
    return run_script("node-lifecycle.sh", status_args)


def node_list() -> int:
    """List cluster nodes together with their Citadel lifecycle labels.

    Runs ``kubectl get nodes`` with a custom-columns layout.

    Returns:
        0 if kubectl succeeded, 1 if it exited non-zero or could not be run.
    """
    print("Nodes with lifecycle labels:")
    columns = ",".join([
        "NAME:.metadata.name",
        # creationTimestamp is an absolute timestamp, not a relative age.
        "CREATED:.metadata.creationTimestamp",
        "READY:.status.conditions[?(@.type==\"Ready\")].status",
        # .spec.unschedulable is true for cordoned nodes, so the header must
        # say UNSCHEDULABLE (the previous SCHEDULABLE label was inverted).
        "UNSCHEDULABLE:.spec.unschedulable",
        "TERMINATE-REASON:.metadata.labels.citadel\\.dev/terminate-reason",
        "POLICY:.metadata.labels.citadel\\.dev/policy-name",
    ])
    try:
        subprocess.run(
            ["kubectl", "get", "nodes", "-o", f"custom-columns={columns}"],
            check=True,
        )
    except subprocess.CalledProcessError as e:
        print(f"Error listing nodes: {e}", file=sys.stderr)
        return 1
    except OSError as e:
        # kubectl missing from PATH or not executable.
        print(f"Error running kubectl: {e}", file=sys.stderr)
        return 1
    return 0


# ---------------------------------------------------------------------------
# Lifecycle commands (unchanged)
# ---------------------------------------------------------------------------


def lifecycle_install() -> int:
    """Apply the node-lifecycle CRD, RBAC and default-policy manifests.

    Manifests are read from the ``crds/`` directory next to this file and
    applied in order with ``kubectl apply -f``. A missing manifest is
    skipped with a warning on stderr rather than silently.

    Returns:
        0 if at least one manifest was applied and kubectl succeeded for all
        of them. 1 if ``crds/`` is missing, no manifest was found, or kubectl
        failed or could not be run.
    """
    crds_dir = Path(__file__).parent / "crds"

    if not crds_dir.is_dir():
        print("Error: crds/ directory not found", file=sys.stderr)
        return 1

    manifests = ("nodeterminationpolicy.yaml", "rbac.yaml", "default-policies.yaml")
    applied = 0
    try:
        for manifest in manifests:
            manifest_path = crds_dir / manifest
            if not manifest_path.exists():
                # Best-effort: keep going, but make the gap visible.
                print(f"  Warning: {manifest} not found, skipping", file=sys.stderr)
                continue
            print(f"  Applying {manifest}...")
            subprocess.run(
                ["kubectl", "apply", "-f", str(manifest_path)],
                check=True,
            )
            applied += 1
    except subprocess.CalledProcessError as e:
        print(f"Error installing lifecycle resources: {e}", file=sys.stderr)
        return 1
    except OSError as e:
        # kubectl missing from PATH or not executable.
        print(f"Error running kubectl: {e}", file=sys.stderr)
        return 1

    if applied == 0:
        print("Error: no lifecycle manifests found in crds/", file=sys.stderr)
        return 1
    if applied < len(manifests):
        print(
            f"Node lifecycle resources partially installed "
            f"({applied}/{len(manifests)} manifests applied)."
        )
    else:
        print("Node lifecycle CRDs, RBAC, and default policies installed.")
    return 0


def lifecycle_policies() -> int:
    """Show active NodeTerminationPolicies via ``kubectl get -o wide``.

    Returns:
        0 if kubectl succeeded, 1 if it exited non-zero or could not be run.
    """
    cmd = ["kubectl", "get", "nodeterminationpolicies.citadel.dev", "-o", "wide"]
    try:
        subprocess.run(cmd, check=True)
    except subprocess.CalledProcessError as e:
        print(f"Error listing policies: {e}", file=sys.stderr)
        return 1
    except OSError as e:
        # kubectl missing from PATH or not executable.
        print(f"Error running kubectl: {e}", file=sys.stderr)
        return 1
    return 0


# ---------------------------------------------------------------------------
# Network commands (unchanged)
# ---------------------------------------------------------------------------


def network_install() -> int:
    """Install the Cilium CNI (with Hubble) by running ``02-install-cilium.sh``.

    Returns the exit status reported by ``run_script``.
    """
    installer = "02-install-cilium.sh"
    print("Installing Cilium CNI...")
    return run_script(installer)


def network_policies() -> int:
    """Apply every ``*.yaml`` file in ``policies/`` (sorted by name) with kubectl.

    Returns:
        0 if every policy was applied. 1 if no policy files were found (the
        directory is missing or empty), or if kubectl failed or could not be run.
    """
    print("Applying network policies...")
    policies_dir = Path(__file__).parent / "policies"

    # Path.glob on a missing directory yields nothing. Without this check the
    # command would report success after applying zero policies.
    policy_files = sorted(policies_dir.glob("*.yaml"))
    if not policy_files:
        print(f"Error: no *.yaml policies found in {policies_dir}", file=sys.stderr)
        return 1

    try:
        for policy_file in policy_files:
            print(f"  Applying {policy_file.name}...")
            subprocess.run(
                ["kubectl", "apply", "-f", str(policy_file)],
                check=True,
            )
    except subprocess.CalledProcessError as e:
        print(f"Error applying policies: {e}", file=sys.stderr)
        return 1
    except OSError as e:
        # kubectl missing from PATH or not executable.
        print(f"Error running kubectl: {e}", file=sys.stderr)
        return 1

    print("Network policies applied")
    return 0


def network_test() -> int:
    """Run the ``05-test-policies.sh`` network-policy test script.

    Returns the exit status reported by ``run_script``.
    """
    banner, script = "Testing network policies...", "05-test-policies.sh"
    print(banner)
    return run_script(script)


# ---------------------------------------------------------------------------
# Monitor / security commands (unchanged)
# ---------------------------------------------------------------------------


def monitor_install() -> int:
    """Run ``03-install-monitoring.sh`` and return its exit status."""
    installer = "03-install-monitoring.sh"
    print("Installing monitoring stack...")
    return run_script(installer)


def security_wireguard() -> int:
    """Run ``06-enable-wireguard.sh`` and return its exit status."""
    script = "06-enable-wireguard.sh"
    print("Enabling WireGuard encryption...")
    return run_script(script)


def security_audit() -> int:
    """Run ``07-enable-audit-logging.sh`` and return its exit status."""
    script = "07-enable-audit-logging.sh"
    print("Enabling audit logging...")
    return run_script(script)


# ---------------------------------------------------------------------------
# Help
# ---------------------------------------------------------------------------


def show_help() -> None:
    """Print the module docstring as CLI usage text.

    Under ``python -OO`` docstrings are stripped and ``__doc__`` is None.
    In that case, print a short usage line instead of the literal "None".
    """
    print(__doc__ or "Usage: citadel-ctl <command> <subcommand> [args]")


# ---------------------------------------------------------------------------
# Main dispatch
# ---------------------------------------------------------------------------


def main():
    """Main CLI entry point."""
    if len(sys.argv) < 2:
        show_help()
        return 1

    command = sys.argv[1]

    # Cluster commands
    if command == "cluster":
        if len(sys.argv) < 3:
            print("Usage: citadel-ctl cluster <status|configure|reset|wait>")
            return 1

        subcommand = sys.argv[2]

        if subcommand == "status":
            return cluster_status()
        elif subcommand == "configure":
            return cluster_configure()
        elif subcommand == "reset":
            return cluster_reset()
        elif subcommand == "wait":
            timeout = 120
            if "--timeout" in sys.argv:
                idx = sys.argv.index("--timeout")
                if idx + 1 < len(sys.argv):
                    try:
                        timeout = int(sys.argv[idx + 1])
                    except ValueError:
                        print("Error: --timeout must be an integer", file=sys.stderr)
                        return 1
            return cluster_wait(timeout)
        else:
            print(f"Unknown cluster command: {subcommand}")
            return 1

    # Image commands
    elif command == "image":
        if len(sys.argv) < 3:
            print("Usage: citadel-ctl image <build|list>")
            return 1

        subcommand = sys.argv[2]

        if subcommand == "build":
            if len(sys.argv) < 5:
                print("Usage: citadel-ctl image build <path> <tag>")
                return 1
            return image_build(sys.argv[3], sys.argv[4])
        elif subcommand == "list":
            return image_list()
        else:
            print(f"Unknown image command: {subcommand}")
            return 1

    # Node commands
    elif command == "node":
        if len(sys.argv) < 3:
            print("Usage: citadel-ctl node <status|list>")
            return 1

        subcommand = sys.argv[2]

        if subcommand == "status":
            return node_status()
        elif subcommand == "list":
            return node_list()
        else:
            print(f"Unknown node command: {subcommand}")
            return 1

    # Lifecycle commands
    elif command == "lifecycle":
        if len(sys.argv) < 3:
            print("Usage: citadel-ctl lifecycle <install|policies>")
            return 1

        subcommand = sys.argv[2]

        if subcommand == "install":
            return lifecycle_install()
        elif subcommand == "policies":
            return lifecycle_policies()
        else:
            print(f"Unknown lifecycle command: {subcommand}")
            return 1

    # Network commands
    elif command == "network":
        if len(sys.argv) < 3:
            print("Usage: citadel-ctl network <install|policies|test>")
            return 1

        subcommand = sys.argv[2]

        if subcommand == "install":
            return network_install()
        elif subcommand == "policies":
            return network_policies()
        elif subcommand == "test":
            return network_test()
        else:
            print(f"Unknown network command: {subcommand}")
            return 1

    # Monitor commands
    elif command == "monitor":
        if len(sys.argv) < 3:
            print("Usage: citadel-ctl monitor <install>")
            return 1

        subcommand = sys.argv[2]

        if subcommand == "install":
            return monitor_install()
        else:
            print(f"Unknown monitor command: {subcommand}")
            return 1

    # Security commands
    elif command == "security":
        if len(sys.argv) < 3:
            print("Usage: citadel-ctl security <enable-wireguard|enable-audit>")
            return 1

        subcommand = sys.argv[2]

        if subcommand == "enable-wireguard":
            return security_wireguard()
        elif subcommand == "enable-audit":
            return security_audit()
        else:
            print(f"Unknown security command: {subcommand}")
            return 1

    # Help
    elif command in ["help", "-h", "--help"]:
        show_help()
        return 0

    else:
        print(f"Unknown command: {command}")
        show_help()
        return 1


if __name__ == "__main__":
    sys.exit(main())
