#!/usr/bin/env python3
"""vre_reader tool -- iterates over a directory output by vre (in the --output_path argument)"""
from argparse import ArgumentParser, Namespace
from pathlib import Path
import sys
import random
from pprint import pformat
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np

from vre.readers import MultiTaskDataset
from vre.representations import build_representations_from_cfg, add_external_repositories, ReprOut
from vre.utils import lo, MemoryData, image_write, str_topk
from vre.logger import vre_logger as logger
from vre_repository import get_vre_repository
from vre_repository.utils import collage_fn

def get_args() -> Namespace:
    """Parse the command line of the reader tool.

    Returns:
        Namespace holding the positional `dataset_path` and every option declared in the table below.
        `dataset_path` and `--config_path` are converted to absolute `Path` objects.
    """
    # Each entry is (flags, keyword options for add_argument). They are registered in this exact order.
    argument_specs = [
        (("dataset_path",), {"type": lambda raw: Path(raw).absolute()}),
        (("--config_path",), {"type": lambda raw: Path(raw).absolute(), "required": True}),
        (("--mode",), {"choices": ["read_one_batch", "iterate_all_data"], "default": "read_one_batch"}),
        (("--external_representations", "-I"),
         {"nargs": "+", "default": [],
          "help": "Path to external reprs. Format: /path/to/file.py:fn_name. fn -> [Representation]"}),
        (("--external_repositories", "-J"),
         {"nargs": "+", "default": [],
          "help": "Path to external reprs. Format: /path/to/file.py:fn_name. fn -> {str: Type[Repr]}"}),
        (("--batch_size",), {"type": int, "default": 5}),
        (("--num_workers",), {"default": 0, "type": int}),
        (("--handle_missing_data",), {"default": "raise"}),
        (("--normalization",), {}),
        (("--save_collages",), {"action": "store_true"}),
    ]

    cli = ArgumentParser()
    for flags, options in argument_specs:
        cli.add_argument(*flags, **options)
    return cli.parse_args()

def main(args: Namespace):
    """Build the configured representations, open the exported dataset and inspect it.

    Only the representations whose name matches an entry of `args.dataset_path` are loaded; the names of
    the others are reported in a warning.

    Depending on `args.mode`:
      - "read_one_batch": logs one random item and one random batch, builds images for every task of the
        batch and, if `args.save_collages` is set, writes one `collage_<name>.png` per batch item
        (relative path, so it ends up in the current working directory).
      - "iterate_all_data": iterates the whole dataset once through a DataLoader, with a progress bar.

    Args:
        args: Parsed CLI arguments. The attributes read here are `dataset_path`, `config_path`, `mode`,
            `external_representations`, `external_repositories`, `batch_size`, `num_workers`,
            `handle_missing_data`, `normalization` and `save_collages`.
    """
    representation_types = add_external_repositories(args.external_repositories, get_vre_repository())
    # https://gitlab.com/video-representations-extractor/video-representations-extractor/-/issues/83
    all_representations = build_representations_from_cfg(cfg=args.config_path,
                                                         representation_types=representation_types,
                                                         external_representations=args.external_representations)
    # List the dataset directory once, instead of once per representation.
    exported_names = {p.name for p in args.dataset_path.iterdir()}
    representations = [r for r in all_representations if r.name in exported_names]
    # ':=' has the lowest precedence, so bind the list itself (never the result of a comparison):
    # this way the warning reports the actual missing names.
    if missing := [r.name for r in all_representations if r.name not in exported_names]:
        logger.warning(f"Not all representations from '{args.config_path}' were exported! Missing: {missing}")
    reader = MultiTaskDataset(args.dataset_path, task_names=[r.name for r in representations],
                              task_types={r.name: r for r in representations},
                              handle_missing_data=args.handle_missing_data, normalization=args.normalization,
                              cache_task_stats=True)
    logger.info(reader)
    logger.info("== Shapes ==")
    logger.info(pformat(reader.data_shape))

    if args.mode == "read_one_batch":
        logger.info("== Random loaded item ==")
        rand_ix = random.randint(0, len(reader) - 1)
        data, name = reader[rand_ix] # get a random item
        logger.info(f"{name=}\n{dict(data.items())}")
        # Sanity check: the loaded item must not contain any NaN.
        assert not any(v.isnan().any() for v in data.values())

        logger.info("== Random loaded batch using torch DataLoader ==")
        loader = DataLoader(reader, collate_fn=reader.collate_fn, batch_size=args.batch_size,
                            shuffle=True, num_workers=0)
        batch_data, names = next(iter(loader))
        logger.info(f"{names=}\n{pformat(dict(batch_data.items()))}")

        print("== Plot each image in the loaded batch ==")
        img_data = {}
        for k, v in batch_data.items():
            # All-zeros placeholder, kept as-is if make_images() fails for this task.
            img_data[k] = np.zeros((len(names), *reader.data_shape[k][0:2], 3))
            key = range(len(v))
            try:
                img_data[k] = reader.name_to_task[k].make_images(ReprOut(output=MemoryData(v), key=key, frames=None))
            except Exception as e:
                # Best effort: a task that cannot be plotted is logged and keeps its placeholder.
                logger.error(e)
        if args.save_collages:
            for i, name in enumerate(names):
                # One collage per batch item: the i-th image of every task, titled with the task name.
                collage = collage_fn([task_imgs[i] for task_imgs in img_data.values()],
                                     titles=[str_topk(k, 20) for k in img_data])
                image_write(collage, f"collage_{name}.png")

        logger.info(pformat({k: lo(v) for k, v in img_data.items()}, width=120))

    if args.mode == "iterate_all_data":
        loader = DataLoader(reader, collate_fn=reader.collate_fn, batch_size=args.batch_size,
                            shuffle=True, num_workers=args.num_workers)
        # Just consume every batch; the point is to check that all the data can be read.
        for _ in tqdm(iter(loader), file=sys.stdout):
            pass

# Script entry point: parse the command line and hand the result to main().
if __name__ == "__main__":
    main(args=get_args())
