#!/usr/bin/env python3
"""vre_reader tool -- iterates over a directory output by vre (in the --output_path argument)"""
from argparse import ArgumentParser, Namespace
import sys
from vre.readers import MultiTaskDataset
from vre.representations import build_representations_from_cfg, add_external_repositories
from vre.utils import abs_path, lo, ReprOut, MemoryData, image_write, collage_fn, str_topk
from vre.logger import vre_logger as logger
from vre_repository import get_vre_repository
from pprint import pformat
from torch.utils.data import DataLoader
from tqdm import tqdm
import random
import numpy as np

def get_args() -> Namespace:
    """Build the command line parser of the reader tool and parse the process arguments with it."""
    cli = ArgumentParser()
    # (flags, options) pairs; they are registered in this exact order so the generated --help is unchanged.
    argument_specs = [
        (("dataset_path",), {"type": abs_path}),
        (("--config_path",), {"type": abs_path, "required": True}),
        (("--mode",), {"choices": ["read_one_batch", "iterate_all_data"], "default": "read_one_batch"}),
        (("--external_representations", "-I"),
         {"nargs": "+", "default": [],
          "help": "Path to external reprs. Format: /path/to/file.py:fn_name. fn -> [Representation]"}),
        (("--external_repositories", "-J"),
         {"nargs": "+", "default": [],
          "help": "Path to external reprs. Format: /path/to/file.py:fn_name. fn -> {str: Type[Repr]}"}),
        (("--batch_size",), {"type": int, "default": 5}),
        (("--num_workers",), {"default": 0, "type": int}),
        (("--handle_missing_data",), {"default": "raise"}),
        (("--normalization",), {}),
        (("--save_collages",), {"action": "store_true"}),
    ]
    for flags, options in argument_specs:
        cli.add_argument(*flags, **options)
    return cli.parse_args()

def main(args: Namespace):
    """Open the directory at args.dataset_path as a MultiTaskDataset and read from it according to args.mode.

    Parameters:
    - args: The namespace returned by get_args().

    Representations are built from args.config_path, then restricted to those whose name matches an entry of
    args.dataset_path; a warning listing the names without a matching entry is logged.

    Modes:
    - "read_one_batch": logs one random item and one random batch, builds images for every task of the batch
      and, if args.save_collages is set, writes one "collage_{name}.png" per item of the batch.
    - "iterate_all_data": iterates once over the whole dataset with a DataLoader, showing a progress bar.
    """
    representation_types = add_external_repositories(args.external_repositories, get_vre_repository())
    # https://gitlab.com/video-representations-extractor/video-representations-extractor/-/issues/83
    _representations = build_representations_from_cfg(cfg=args.config_path, representation_types=representation_types,
                                                      external_representations=args.external_representations)
    # List the directory a single time instead of once per configured representation.
    exported_names = {p.name for p in args.dataset_path.iterdir()}
    representations = [r for r in _representations if r.name in exported_names]
    # The walrus binds the list itself (not the result of a comparison), so the warning shows the actual names.
    if missing := [r.name for r in _representations if r.name not in exported_names]:
        logger.warning(f"Not all representations from '{args.config_path}' were exported! Missing: {missing}")
    reader = MultiTaskDataset(args.dataset_path, task_names=[r.name for r in representations],
                              task_types={r.name: r for r in representations},
                              handle_missing_data=args.handle_missing_data, normalization=args.normalization,
                              cache_task_stats=True)
    logger.info(reader)
    logger.info("== Shapes ==")
    logger.info(pformat(reader.data_shape))

    if args.mode == "read_one_batch":
        logger.info("== Random loaded item ==")
        rand_ix = random.randint(0, len(reader) - 1)
        data, name = reader[rand_ix] # get a random item
        logger.info(f"{name=}")
        logger.info(dict(data.items()))
        assert not any(v.isnan().any() for v in data.values())

        logger.info("== Random loaded batch using torch DataLoader ==")
        loader = DataLoader(reader, collate_fn=reader.collate_fn, batch_size=args.batch_size,
                            shuffle=True, num_workers=0)
        batch_data, names = next(iter(loader))
        logger.info(f"{names=}")
        logger.info(pformat(dict(batch_data.items()))) # Nones are converted to 0s automagically

        print("== Plot each image in the loaded batch ==")
        img_data = {}
        for k, v in batch_data.items():
            # Placeholder of zeros; it is what remains for this task if make_images fails below.
            img_data[k] = np.zeros((len(names), *reader.data_shape[k][0:2], 3))
            key = range(len(v))
            try:
                img_data[k] = reader.name_to_task[k].make_images(ReprOut(output=MemoryData(v), key=key, frames=None))
            except Exception as e: # best effort: log the failure and keep the placeholder for this task
                logger.error(e)
        if args.save_collages:
            titles = [str_topk(k, 20) for k in img_data]
            for i, name in enumerate(names):
                collage = collage_fn([task_imgs[i] for task_imgs in img_data.values()], titles=titles)
                image_write(collage, f"collage_{name}.png")

        logger.info(pformat({k: lo(v) for k, v in img_data.items()}, width=120))

    if args.mode == "iterate_all_data":
        loader = DataLoader(reader, collate_fn=reader.collate_fn, batch_size=args.batch_size,
                            shuffle=True, num_workers=args.num_workers)
        # Only consume the batches; the point is to check that every item can be loaded.
        for _ in tqdm(iter(loader), file=sys.stdout):
            pass

if __name__ == "__main__":
    # Script entry point: parse the command line first, then hand the namespace to main().
    cli_args = get_args()
    main(cli_args)
