#!/usr/bin/env python3
import sys
import os
import subprocess
import signal
from functools import partial
from argparse import ArgumentParser, Namespace, REMAINDER
from vre_video import VREVideo

def _cleanup(processes: list, signum, frame):
    """Terminate every process in `processes`, then exit with status 1.

    Parameters:
    - processes: objects exposing a `terminate()` method (main passes the
      `subprocess.Popen` handles it started).
    - signum, frame: accepted but not used by this function.
    """
    print("Terminating subprocesses...")
    for proc in processes:
        proc.terminate()
    # Same effect as sys.exit(1): unwinds the interpreter with exit status 1.
    raise SystemExit(1)

def get_args() -> Namespace:
    """Parse the command line.

    Returns a Namespace with:
    - video_path: the first positional argument, as a string.
    - vre_args: every argument after `video_path`, collected verbatim as a list.
    """
    cli = ArgumentParser(description="Distribute VRE processing across multiple GPUs")
    positionals = (
        ("video_path", {"type": str, "help": "Path to the video file"}),
        ("vre_args", {"nargs": REMAINDER, "help": "Additional arguments to pass to VRE"}),
    )
    for dest, options in positionals:
        cli.add_argument(dest, **options)
    return cli.parse_args()

def _split_frames(start_frame: int, end_frame: int, n_parts: int) -> list:
    """Split the span between start_frame and end_frame into contiguous chunks.

    Returns a list of (chunk_start, chunk_end) tuples in which each chunk starts
    where the previous one ended; the first starts at start_frame and the last
    ends at end_frame. The remainder of the division is handed out one frame at
    a time to the first chunks, so chunk sizes differ by at most one. Chunks of
    size zero are omitted, so fewer than n_parts tuples may be returned.
    """
    base, extra = divmod(end_frame - start_frame, n_parts)
    ranges = []
    cursor = start_frame
    for part in range(n_parts):
        size = base + (1 if part < extra else 0)
        if size == 0:
            continue
        ranges.append((cursor, cursor + size))
        cursor += size
    return ranges


def main(args: Namespace):
    """Launch one `vre` subprocess per GPU, each on its own slice of the frames.

    The GPU list is read from the CUDA_VISIBLE_DEVICES environment variable
    (comma separated). The frame span comes from a `--frames START..END` pair in
    `args.vre_args` when present, otherwise it is `0..len(VREVideo(video_path))`.
    Each child is started with CUDA_VISIBLE_DEVICES set to a single GPU and
    VRE_DEVICE set to "cuda", and receives the remaining `vre_args` unchanged.
    The function blocks until every child has exited.

    Parameters:
    - args: Namespace with `video_path` (str) and `vre_args` (list of str).
      `args.vre_args` is not modified.

    Raises:
    - AssertionError: CUDA_VISIBLE_DEVICES is unset or lists no devices.
    """
    # Explicit raises (rather than `assert` statements) keep this validation,
    # and the GPU count it produces, working under `python -O`.
    visible_devices = os.getenv("CUDA_VISIBLE_DEVICES")
    if visible_devices is None:
        raise AssertionError("CUDA_VISIBLE_DEVICES must be set")
    # "".split(",") yields [""], so blank entries are dropped before counting.
    gpu_indices = [gpu.strip() for gpu in visible_devices.split(",") if gpu.strip()]
    if not gpu_indices:
        raise AssertionError(f"CUDA_VISIBLE_DEVICES lists no devices: {visible_devices!r}")

    # Work on a copy so the caller's argument list is left untouched.
    vre_args = list(args.vre_args)
    if "--frames" in vre_args:
        ix = vre_args.index("--frames")
        vre_args.pop(ix)  # drop the --frames flag itself
        # the element now at ix is the flag's value: "START..END"
        start_frame, end_frame = map(int, vre_args.pop(ix).split(".."))
    else:
        video = VREVideo(args.video_path)
        start_frame, end_frame = 0, len(video)

    frame_ranges = _split_frames(start_frame, end_frame, len(gpu_indices))

    processes = []
    # Signal handlers are invoked as handler(signum, frame), so `processes` must
    # be bound positionally; binding it by keyword collides with the first
    # positional parameter of _cleanup and raises TypeError when a signal
    # arrives. The handlers are installed before spawning so that an interrupt
    # during start-up still terminates the children already appended to the list.
    on_signal = partial(_cleanup, processes)
    signal.signal(signal.SIGINT, on_signal)
    signal.signal(signal.SIGTERM, on_signal)

    # zip stops at the shorter sequence: GPUs without a frame range stay idle.
    for gpu, (gpu_start_frame, gpu_end_frame) in zip(gpu_indices, frame_ranges):
        env = {"CUDA_VISIBLE_DEVICES": gpu, "VRE_DEVICE": "cuda"}
        cmd = ["vre", args.video_path, "--frames", f"{gpu_start_frame}..{gpu_end_frame}"] + vre_args

        env_str = " ".join([f"{k}={v}" for k, v in env.items()])
        print(f"Executing command: {env_str} {' '.join(cmd)}")

        # The per-GPU overrides take precedence over the inherited environment.
        processes.append(subprocess.Popen(cmd, env={**os.environ, **env}))

    # Wait for all processes to complete
    for p in processes:
        p.wait()

# Script entry point: parse the command line and hand the result to main().
if __name__ == "__main__":
    main(get_args())
