#!/usr/bin/env python3
"""vre_collage tool -- makes a collage from a directory output by vre (for all image representations)"""
from argparse import ArgumentParser, Namespace
from pathlib import Path
from pprint import pformat
from tempfile import NamedTemporaryFile
import shutil
import ffmpeg
from omegaconf import OmegaConf
from tqdm import trange
import numpy as np
import torch as tr

from vre.utils import image_write, collage_fn, abs_path, reorder_dict, MemoryData
from vre.representations import build_representations_from_cfg, add_external_representations, Representation, ReprOut
from vre.readers.multitask_dataset import MultiTaskDataset
from vre.logger import vre_logger as logger

def plot_one(data: dict[str, tr.Tensor], file_name: str, dataset_path: Path,
             name_to_task: dict[str, Representation]) -> dict[str, np.ndarray]:
    """
    Plots a single data point for all the representations in data, using the matching task of name_to_task.
    Note: "rgb" must always be part of data, because some reprs require rgb for plotting (m2f).

    Parameters:
    - data: The tensors of one data point, keyed by representation name.
    - file_name: The name of the data point. Its stem is used as the ReprOut key and to find the extra files.
    - dataset_path: The root of the dataset. Extras are looked up at '<dataset_path>/<name>/npz/<stem>_extra.npz'.
    - name_to_task: The representation object used to make the image of each entry of data.

    Returns a dict holding, for each key of data (same order), the first image returned by task.make_images().
    Raises KeyError if data is not empty and has no "rgb" entry, or if a key of data is not in name_to_task.
    """
    if not data:
        return {} # nothing to plot; "rgb" is only required when there is at least one entry
    # Both are the same for all the representations, so they are looked up only once.
    stem = Path(file_name).stem
    rgb = data["rgb"] # must be here

    img_data = {}
    for name, tensor in data.items():
        task = name_to_task[name]
        extra = None
        if (extra_path := (dataset_path/name/"npz"/f"{stem}_extra.npz")).exists():
            # NOTE: allow_pickle=True unpickles arbitrary objects, so dataset_path must be a trusted directory.
            # The context manager closes the npz archive, which a bare np.load() call would leave open.
            with np.load(extra_path, allow_pickle=True) as npz_file:
                extra = npz_file["arr_0"].item()
        # [None] adds a leading axis of size 1: the task gets a batch made of this single data point.
        frames = rgb.cpu().numpy()[None]
        output = MemoryData(tensor.cpu().numpy()[None])
        task.data = ReprOut(frames, output, key=[stem], extra=[extra])
        img_data[name] = task.make_images()[0]
    return img_data

def get_args() -> Namespace:
    """
    Cli args: parses and validates the command line of the tool.

    Returns the parsed arguments. Exits via parser.error() (usage message on stderr, exit status 2) if the output
    path already exists and --overwrite is not set, or if --video is set without a strictly positive --fps.
    """
    parser = ArgumentParser()
    parser.add_argument("vre_export_dir", type=lambda p: Path(p).absolute(), help="Path to the dir where VRE exported")
    parser.add_argument("--config_path", required=True, type=abs_path, help="Path to the config for representations")
    parser.add_argument("--external_representations", "-I", nargs="+", default=[],
                        help="Path to external reprs. Format: /path/to/file.py:fn_name. fn -> {str: Representation}")
    parser.add_argument("--output_path", "-o", required=True, type=abs_path, help="Path to the collage is stored")
    parser.add_argument("--overwrite", action="store_true")
    parser.add_argument("--video", action="store_true")
    parser.add_argument("--fps", type=float)
    # NOTE(review): --output_resolution is parsed, but nothing in the visible part of this file reads it.
    parser.add_argument("--output_resolution", type=int, nargs=2)
    args = parser.parse_args()

    # These are user input checks, so they must not be asserts: those are stripped by 'python -O', which would
    # let an existing output path be removed without --overwrite, or let the video export run with no fps.
    if args.output_path.exists() and not args.overwrite:
        parser.error(f"Output path '{args.output_path}' exists. Use --overwrite.")
    if args.video and not (args.fps is not None and args.fps > 0):
        parser.error("If video is set, --fps must also be set.")
    return args

def main(args: Namespace):
    """
    Main fn: writes one collage png for each data point of a VRE export dir and, optionally, a video of them.

    Parameters:
    - args: The cli arguments (vre_export_dir, config_path, external_representations, output_path, video, fps).

    Any existing args.output_path is removed and recreated. The collages are stored at '<output_path>/<stem>.png'
    and, if args.video is set, the video is stored at '<output_path>/collage.mp4'.
    """
    video_path = args.output_path / "collage.mp4"
    logger.info(f"VRE exported dir: {args.vre_export_dir}")
    logger.info(f"Output dir: {args.output_path}")
    if args.video:
        logger.info(f"Export video: '{video_path}'")

    cfg = OmegaConf.to_container(OmegaConf.load(args.config_path), resolve=True)
    representations = build_representations_from_cfg(cfg=cfg)
    for external_path in args.external_representations:
        representations = add_external_representations(representations, external_path, cfg)
    assert "rgb" in representations, "RGB(rgb) is needed for collage plot"
    logger.debug(pformat(representations))
    reader = MultiTaskDataset(args.vre_export_dir, task_names=list(representations.keys()),
                              task_types=representations, handle_missing_data="fill_nan",
                              normalization="min_max", cache_task_stats=True, batch_size_stats=100)
    # Make sure all the data is on disk and won't be computed on the fly in vre_collage. This is vre's job.
    assert all(isinstance(reader.files_per_repr[task][0], Path) for task in reader.task_names)

    logger.info(reader)
    logger.info(f"== Shapes ==\n{pformat(reader.data_shape)}")

    if args.output_path.exists():
        shutil.rmtree(args.output_path)
    args.output_path.mkdir(parents=True, exist_ok=False)

    collage_paths = []
    for ix in trange(len(reader), desc="collage"):
        data, name, _ = reader[ix] # items are visited in order, which is also the order of the video frames
        img_data = reorder_dict(plot_one(data, name, reader.path, name_to_task=reader.name_to_task), ["rgb"])
        # Long titles keep only their first and last 19 characters, so that they fit above their image.
        titles = [title if len(title) < 40 else f"{title[0:19]}..{title[-19:]}" for title in img_data.keys()]
        collage = collage_fn(list(img_data.values()), titles=titles, size_px=40)
        # Same stem as the one used by plot_one. Splitting on '.' by hand gives an empty stem if there is no dot.
        out_name = args.output_path / f"{Path(name).stem}.png"
        image_write(collage, out_name)
        collage_paths.append(str(out_name))
    logger.info(f"Collage written at '{args.output_path}'")

    if args.video:
        _write_video(collage_paths, args.fps, video_path)
        logger.info(f"Video written at '{video_path}'")

def _write_video(frame_paths: list[str], fps: float, video_path: Path):
    """Encodes the images at frame_paths, in order, as a video at video_path. Each one lasts 1/fps seconds."""
    entries = []
    for frame_path in frame_paths:
        # In the list read by ffmpeg, a quote inside a single-quoted path must be written as '\''
        escaped_path = frame_path.replace("'", "'\\''")
        entries.extend([f"file '{escaped_path}'", f"duration {1 / fps:.3f}"])

    # The list must be closed (thus flushed) before ffmpeg opens it by name, so it cannot be deleted on close.
    with NamedTemporaryFile("w", suffix=".txt", delete=False) as concat_file:
        concat_file.write("\n".join(entries))
    try:
        (
            ffmpeg
            .input(concat_file.name, format="concat", safe=0)
            .output(str(video_path), vcodec="libx265", crf=28, pix_fmt="yuv420p")
            .run()
        )
    finally:
        # The list is only needed while ffmpeg runs; remove it even if the encoding fails.
        Path(concat_file.name).unlink(missing_ok=True)

if __name__ == "__main__":
    # Script entry point: parse the command line first, then build the collage (and the video, if requested).
    cli_args = get_args()
    main(cli_args)
