#!/usr/bin/env python3
"""vre_collage tool -- makes a collage from a directory output by vre (for all image representations)"""
from argparse import ArgumentParser, Namespace
from pathlib import Path
from pprint import pformat
from functools import partial
from tempfile import NamedTemporaryFile
from multiprocessing import Pool
import shutil
import ffmpeg
from omegaconf import OmegaConf
from natsort import natsorted
from tqdm import tqdm
import numpy as np
import torch as tr

from vre.utils import image_write, collage_fn, abs_path, reorder_dict, MemoryData, image_resize
from vre.representations import build_representations_from_cfg, add_external_repositories, Representation, ReprOut
from vre.readers.multitask_dataset import MultiTaskDataset
from vre.logger import vre_logger as logger
from vre_repository import get_vre_repository

def plot_one(data: dict[str, tr.Tensor], file_name: str, dataset_path: Path,
             name_to_task: dict[str, Representation]) -> dict[str, np.ndarray]:
    """
    Plots a single data point for all representations from MultiTaskReader based on repr.data
    Note: RGB must be here too always, use data.get("rgb") if you want, but some reprs require rgb for plotting (m2f)

    Parameters:
    - data: representation name -> tensor holding that representation for this data point. Must contain "rgb".
    - file_name: name of the data point; its stem is used as the key and to look up "<stem>_extra.npz".
    - dataset_path: root of the exported dataset; extras are looked up at <dataset_path>/<name>/npz/.
    - name_to_task: representation name -> representation object used to make the image. Each object's
      `data` attribute is overwritten as a side effect.

    Returns: representation name -> first image returned by that representation's `make_images`.
    Raises: KeyError if "rgb" or a task name is missing; re-raises whatever `make_images` raises after logging.
    """
    if not data:
        return {}

    # The stem and the uint8 RGB frame are identical for every representation, so compute them once.
    # NOTE(review): the same rgb array is now shared by all tasks -- assumes make_images does not mutate it.
    stem = Path(file_name).stem
    rgb = data["rgb"].cpu().numpy()
    if np.issubdtype(rgb.dtype, np.floating):
        rgb = (rgb * 255).astype(np.uint8)
    assert np.issubdtype(rgb.dtype, np.uint8), rgb

    img_data = {}
    for k, v in data.items():
        task = name_to_task[k]
        extra = None
        extra_path = dataset_path / k / "npz" / f"{stem}_extra.npz"
        if extra_path.exists():
            # NOTE: allow_pickle=True unpickles arbitrary objects; only use on exports you trust.
            # The context manager closes the underlying npz file handle once the value is read.
            with np.load(extra_path, allow_pickle=True) as npz_file:
                extra = npz_file["arr_0"].item()
        # [None] adds a leading axis of size 1 so the single item is passed as a batch of one.
        task.data = ReprOut(rgb[None], MemoryData(v.cpu().numpy()[None]), key=[stem], extra=[extra])
        try:
            img_data[k] = task.make_images(task.data)[0]
        except Exception:
            logger.info(f"[{task}] Failed {dataset_path=}, {file_name=}")
            raise  # bare raise keeps the original traceback intact
    return img_data

def do_one(ix: int, reader: MultiTaskDataset, output_path: Path, output_resolution: tuple[int, int] | None):
    """
    worker fn for making collages: plots item `ix` of the reader and writes the collage as a png.

    Parameters:
    - ix: index of the item in `reader`.
    - reader: the dataset; `reader[ix]` must return a (data, name) pair.
    - output_path: directory where the png is written.
    - output_resolution: if not None, the collage is resized with `image_resize(collage, *output_resolution)`
      before being written (argument order is whatever image_resize expects -- TODO confirm height/width).
    """
    data, name = reader[ix]
    # Swap only the last extension for .png. Unlike splitting on '.', this also works for names without
    # any extension, which previously all collapsed to the same '.png' output file.
    out_path = output_path / Path(name).with_suffix(".png")
    # reorder_dict is called with ["rgb"] -- presumably to place rgb first in the collage; verify in vre.utils.
    img_data = reorder_dict(plot_one(data, name, reader.path, name_to_task=reader.name_to_task), ["rgb"])
    # Titles of 40+ characters are shortened to their first and last 19 characters joined by '..'.
    titles = [title if len(title) < 40 else f"{title[0:19]}..{title[-19:]}" for title in img_data]
    collage = collage_fn(list(img_data.values()), titles=titles, size_px=40)
    if output_resolution is not None:
        collage = image_resize(collage, *output_resolution)
    image_write(collage, out_path)

def get_args() -> Namespace:
    """
    CLI args. Parses sys.argv and validates the combination of flags.

    Returns: the parsed Namespace.
    Exits (SystemExit with status 2, via parser.error) if the output path already exists and --overwrite
    was not given, or if --video is set without a positive --fps.
    """
    parser = ArgumentParser()
    parser.add_argument("vre_export_dir", type=lambda p: Path(p).absolute(), help="Path to the dir where VRE exported")
    parser.add_argument("--config_path", required=True, type=abs_path, help="Path to the config for representations")
    parser.add_argument("--external_representations", "-I", nargs="+", default=[],
                        help="Path to external reprs. Format: /path/to/file.py:fn_name. fn -> [Representation]")
    parser.add_argument("--external_repositories", "-J", nargs="+", default=[],
                        help="Path to external reprs. Format: /path/to/file.py:fn_name. fn -> {str: Type[Repr]}")
    parser.add_argument("--output_path", "-o", required=True, type=abs_path, help="Path to the collage is stored")
    parser.add_argument("--overwrite", action="store_true")
    parser.add_argument("--video", action="store_true")
    parser.add_argument("--fps", type=float)
    parser.add_argument("--output_resolution", type=int, nargs=2)
    parser.add_argument("--normalization", help="The type of normalization. Valid 'min_max', 'standardization' or None")
    parser.add_argument("--n_workers", type=int, default=1)
    args = parser.parse_args()
    # Explicit checks instead of assert: asserts vanish under `python -O`, and main() deletes an existing
    # output directory, so the overwrite guard must never be optimized away.
    if args.output_path.exists() and not args.overwrite:
        parser.error(f"Output path '{args.output_path}' exists. Use --overwrite.")
    # `not fps > 0` (rather than `fps <= 0`) also rejects NaN.
    if args.video and (args.fps is None or not args.fps > 0):
        parser.error("If video is set, --fps must also be set.")
    return args

def main(args: Namespace):
    """
    main fn: builds the representations from the config, reads the exported data, writes one collage png
    per item into args.output_path and, if args.video is set, encodes them into 'collage.mp4' with ffmpeg.

    Side effects: an existing args.output_path is deleted and re-created before anything is written.
    """
    logger.info(f"VRE exported dir: {args.vre_export_dir}")
    logger.info(f"Output dir: {args.output_path}")
    if args.video:
        logger.info(f"Export video: '{args.output_path / 'collage.mp4'}'")

    cfg = OmegaConf.to_container(OmegaConf.load(args.config_path), resolve=True)
    representation_types = add_external_repositories(args.external_repositories, get_vre_repository())
    representations = build_representations_from_cfg(cfg=cfg, representation_types=representation_types,
                                                     external_representations=args.external_representations)

    assert "rgb" in [r.name for r in representations], "RGB(rgb) is needed for collage plot"
    logger.debug(f"Representations:\n{pformat(representations)}")
    reader = MultiTaskDataset(args.vre_export_dir, task_names=[r.name for r in representations],
                              task_types={r.name: r for r in representations}, handle_missing_data="fill_nan",
                              normalization=args.normalization, cache_task_stats=True, batch_size_stats=100)
    # Make sure all the data is on disk and won't be computed on the fly in vre_collage. This is vre's job.
    assert all(isinstance(reader.files_per_repr[task][0], Path) for task in reader.task_names)

    logger.info(reader)
    logger.info(f"== Shapes ==\n{pformat(reader.data_shape)}")

    if args.output_path.exists():
        shutil.rmtree(args.output_path)
    args.output_path.mkdir(parents=True, exist_ok=False)

    f_do_one = partial(do_one, reader=reader, output_path=args.output_path, output_resolution=args.output_resolution)
    n_items = len(reader)
    if args.n_workers == 1:
        for ix in tqdm(range(n_items)):
            f_do_one(ix)
    else:
        # The context manager shuts the worker processes down instead of leaking them. tqdm wraps the
        # results (not the inputs) so the bar advances as collages are finished, not as tasks are queued.
        with Pool(args.n_workers) as pool:
            for _ in tqdm(pool.imap(f_do_one, range(n_items)), total=n_items):
                pass
    logger.info(f"Collage written at '{args.output_path}'")

    if args.video:
        # ffmpeg's concat demuxer needs a list file: one 'file' line per png, each shown for 1/fps seconds.
        frame_duration = f"duration {1 / args.fps:.3f}"
        entries = []
        for out_png_name in natsorted(list(args.output_path.iterdir()), key=lambda p: p.name):
            entries.extend([f"file '{out_png_name}'", frame_duration])
        # delete=False: the list must be closed (flushed) and still exist when ffmpeg opens it by name.
        with NamedTemporaryFile("w", suffix=".txt", delete=False) as concat_file:
            concat_file.write("\n".join(entries))
        vid_path = f"{args.output_path}/collage.mp4"
        try:
            (
                ffmpeg
                .input(concat_file.name, format="concat", safe=0)
                .output(vid_path, vcodec="libx265", crf=28, pix_fmt="yuv420p")
                .run()
            )
        finally:
            # Remove the temporary list file whether or not ffmpeg succeeded.
            Path(concat_file.name).unlink(missing_ok=True)
        logger.info(f"Video written at '{vid_path}'")

# Script entry point: parse the command line, then run the collage export.
if __name__ == "__main__":
    cli_args = get_args()
    main(cli_args)
