#!/usr/bin/env python3
import sys
import os
import subprocess
import signal
from functools import partial
from argparse import ArgumentParser, Namespace, REMAINDER
from vre import FFmpegVideo

def _cleanup(processes: list, signum, frame):
    """Ask every subprocess in `processes` to terminate, then exit with status 1.

    Args:
        processes: objects exposing a ``terminate()`` method.
        signum: accepted to fit a signal-handler style call; not used.
        frame: accepted to fit a signal-handler style call; not used.

    Raises:
        SystemExit: always, with exit code 1.
    """
    print("Terminating subprocesses...")
    for child in processes:
        child.terminate()
    # Same effect as sys.exit(1), which raises SystemExit(1).
    raise SystemExit(1)

def get_args() -> Namespace:
    """Build the command-line parser and return the parsed arguments.

    The returned namespace carries ``video_file`` (positional), ``frames``
    (optional string, ``None`` when omitted) and ``vre_args`` (the list of
    remaining arguments, meant to be forwarded to VRE).
    """
    cli = ArgumentParser(description="Distribute VRE processing across multiple GPUs")
    cli.add_argument(
        "video_file",
        type=str,
        help="Path to the video file",
    )
    cli.add_argument(
        "--frames",
        type=str,
        help="Frame range in the format 'start..end'",
    )
    # REMAINDER collects everything left on the command line into a list.
    cli.add_argument(
        "vre_args",
        nargs=REMAINDER,
        help="Additional arguments to pass to VRE",
    )
    return cli.parse_args()

def _partition_frames(start_frame: int, end_frame: int, n_parts: int) -> list:
    """Split start_frame..end_frame into at most `n_parts` contiguous (start, end) pairs.

    Each pair's end equals the next pair's start; the last pair ends at
    `end_frame`. Requires ``end_frame > start_frame`` and ``n_parts >= 1``.
    """
    total_frames = end_frame - start_frame
    # Never create more partitions than there are frames: with integer division
    # every partition but the last would otherwise get a zero-length range.
    n_parts = min(n_parts, total_frames)
    frames_per_part = total_frames // n_parts

    partitions = []
    for i in range(n_parts):
        part_start = start_frame + i * frames_per_part
        # The last partition absorbs the remainder of the integer division.
        part_end = end_frame if i == n_parts - 1 else part_start + frames_per_part
        partitions.append((part_start, part_end))
    return partitions


def main(args: Namespace):
    """Run one `vre` subprocess per visible GPU, each on its own slice of the frame range.

    The GPUs are read from the CUDA_VISIBLE_DEVICES environment variable
    (comma separated). The frame range comes from ``args.frames``
    ('start..end') or defaults to ``0..len(video)``. Every subprocess is
    started with CUDA_VISIBLE_DEVICES narrowed to a single GPU and
    VRE_DEVICE=cuda, and receives ``args.vre_args`` unchanged.

    Args:
        args: namespace with ``video_file``, ``frames`` and ``vre_args``.

    Raises:
        RuntimeError: CUDA_VISIBLE_DEVICES is unset or lists no device.
        ValueError: the frame range is malformed or contains no frames.
        SystemExit: exit code 1 if any subprocess finished with a non-zero
            return code (after all of them have been waited for).
    """
    visible_devices = os.getenv("CUDA_VISIBLE_DEVICES")
    if visible_devices is None:
        raise RuntimeError("CUDA_VISIBLE_DEVICES must be set to a comma-separated list of GPUs")
    # str.split always returns at least one element, so drop empty entries
    # explicitly to catch CUDA_VISIBLE_DEVICES="".
    gpu_indices = [gpu.strip() for gpu in visible_devices.split(",") if gpu.strip()]
    if not gpu_indices:
        raise RuntimeError(f"CUDA_VISIBLE_DEVICES lists no GPU: {visible_devices!r}")

    video_file = args.video_file
    video = FFmpegVideo(video_file)

    if args.frames:
        try:
            start_frame, end_frame = map(int, args.frames.split(".."))
        except ValueError as err:
            raise ValueError(f"--frames must have the format 'start..end', got {args.frames!r}") from err
    else:
        start_frame, end_frame = 0, len(video)

    # NOTE(review): this treats the end of the range as exclusive, matching the
    # default of 0..len(video) — confirm against how `vre` reads --frames.
    if end_frame <= start_frame:
        raise ValueError(f"Empty frame range: {start_frame}..{end_frame}")

    partitions = _partition_frames(start_frame, end_frame, len(gpu_indices))

    processes = []
    # Install the handlers before spawning so an early Ctrl-C still terminates
    # the children started so far (the handler sees the same list object).
    # `processes` is bound positionally: the signal machinery calls the handler
    # as handler(signum, frame), which would collide with a keyword-bound first
    # parameter and raise TypeError.
    handler = partial(_cleanup, processes)
    signal.signal(signal.SIGINT, handler)
    signal.signal(signal.SIGTERM, handler)

    # zip stops at the shorter sequence: GPUs beyond the number of partitions stay idle.
    used_gpus = gpu_indices[: len(partitions)]
    for gpu, (gpu_start_frame, gpu_end_frame) in zip(used_gpus, partitions):
        env = {"CUDA_VISIBLE_DEVICES": gpu, "VRE_DEVICE": "cuda"}
        cmd = ["vre", video_file, "--frames", f"{gpu_start_frame}..{gpu_end_frame}"] + args.vre_args

        env_str = " ".join(f"{k}={v}" for k, v in env.items())
        print(f"Executing command: {env_str} {' '.join(cmd)}")

        # Child inherits the current environment with the two overrides applied.
        processes.append(subprocess.Popen(cmd, env={**os.environ, **env}))

    # Wait for all processes to complete, remembering the ones that failed.
    failures = []
    for gpu, process in zip(used_gpus, processes):
        return_code = process.wait()
        if return_code != 0:
            failures.append((gpu, return_code))

    if failures:
        for gpu, return_code in failures:
            print(f"vre on GPU {gpu} exited with code {return_code}", file=sys.stderr)
        sys.exit(1)

if __name__ == "__main__":
    # Script entry point: parse the command line, then hand the result to main.
    cli_args = get_args()
    main(cli_args)
