#!/usr/bin/env python3
"""vre_reader tool -- iterates over a directory output by vre (in the --output_path argument)"""
from argparse import ArgumentParser, Namespace
import sys
from vre.readers import MultiTaskDataset
from vre.representations import build_representations_from_cfg, add_external_representations
from vre.utils import abs_path, lo
from vre.logger import vre_logger as logger
from omegaconf import OmegaConf
from pprint import pformat
from torch.utils.data import DataLoader
from tqdm import tqdm
import random
import numpy as np

def get_args() -> Namespace:
    """Build the command line interface of the tool and parse `sys.argv`.

    Positional:
        dataset_path: directory to read, converted with `abs_path`.
    Options:
        --config_path (required): config file, converted with `abs_path`.
        --mode: "read_one_batch" (default) or "iterate_all_data".
        --external_representations / -I: zero or more `/path/to/file.py:fn_name` entries.
        --batch_size: int, defaults to 5.
        --num_workers: int, defaults to 0.
        --handle_missing_data: free-form string, defaults to "raise".

    Returns:
        The parsed arguments as an `argparse.Namespace`.
    """
    cli = ArgumentParser()
    cli.add_argument("dataset_path", type=abs_path)
    cli.add_argument("--config_path", required=True, type=abs_path)
    cli.add_argument(
        "--mode",
        default="read_one_batch",
        choices=["read_one_batch", "iterate_all_data"],
    )
    cli.add_argument(
        "--external_representations", "-I",
        default=[],
        nargs="+",
        help="Path to external reprs. Format: /path/to/file.py:fn_name. fn -> {str: Representation}",
    )
    cli.add_argument("--batch_size", default=5, type=int)
    cli.add_argument("--num_workers", type=int, default=0)
    cli.add_argument("--handle_missing_data", default="raise")
    return cli.parse_args()

def _log_random_item(reader: "MultiTaskDataset") -> None:
    """Load one randomly chosen item from `reader`, log it and check that it holds no NaNs."""
    logger.info("== Random loaded item ==")
    rand_ix = random.randint(0, len(reader) - 1)
    data, name, _ = reader[rand_ix]
    logger.info(f"{name=}")
    logger.info(dict(data))
    # The batch path below treats None entries as legal, so skip them here instead of crashing on
    # `None.isnan()`. NOTE(review): presumably None marks a missing task for this item -- confirm.
    assert not any(v is not None and v.isnan().any() for v in data.values()), f"NaNs found in item '{name}'"


def _log_random_batch(reader: "MultiTaskDataset", batch_size: int) -> None:
    """Load one shuffled batch through a torch DataLoader, log it, then log a summary of its plotted images.

    Parameters:
        reader: the dataset to sample from; its `collate_fn` is used to assemble the batch.
        batch_size: number of items requested for the batch.
    """
    logger.info("== Random loaded batch using torch DataLoader ==")
    # A single batch is read, so the loader stays in the main process (num_workers=0), as before.
    loader = DataLoader(reader, collate_fn=reader.collate_fn, batch_size=batch_size,
                        shuffle=True, num_workers=0)
    batch_data, names, _ = next(iter(loader))
    logger.info(f"{names=}")
    logger.info(pformat(dict(batch_data))) # Nones are converted to 0s automagically

    # Logged (not printed) so that this header ends up on the same channel as all the others.
    logger.info("== Plot each image in the loaded batch ==")
    img_data = {}
    for k, v in batch_data.items():
        # Every task gets a zero-filled (batch, *first two dims of its data shape, 3) array, which is
        # only overwritten when the task has data and knows how to plot itself.
        img_data[k] = np.zeros((len(names), *reader.data_shape[k][0:2], 3))
        if v is None:
            continue
        task = reader.name_to_task[k]
        if not hasattr(task, "plot_fn"):
            continue
        for i in range(len(names)):
            img_data[k][i] = task.plot_fn(v[i])

    logger.info(pformat({k: lo(v) for k, v in img_data.items()}, width=120))


def _iterate_all_data(reader: "MultiTaskDataset", batch_size: int, num_workers: int) -> None:
    """Go once through the whole dataset in shuffled batches, showing a progress bar on stdout."""
    loader = DataLoader(reader, collate_fn=reader.collate_fn, batch_size=batch_size,
                        shuffle=True, num_workers=num_workers)
    for _ in tqdm(loader, file=sys.stdout):
        pass


def main(args: Namespace):
    """Build the representations and the dataset reader from the CLI args, then run the requested mode.

    Parameters:
        args: namespace produced by `get_args`. The attributes read are `config_path`,
            `external_representations`, `dataset_path`, `handle_missing_data`, `mode`, `batch_size`
            and `num_workers`.

    Modes:
        "read_one_batch": logs one random item and one random batch (plus a summary of its plots).
        "iterate_all_data": iterates over the entire dataset once, with a progress bar.
    """
    cfg = OmegaConf.to_container(OmegaConf.load(args.config_path), resolve=True)
    representations = build_representations_from_cfg(cfg=cfg)
    # Each external entry extends the representations built so far; an empty list is a no-op.
    for external_path in args.external_representations:
        representations = add_external_representations(representations, external_path, cfg)

    reader = MultiTaskDataset(args.dataset_path, task_names=list(representations.keys()),
                              task_types=representations, handle_missing_data=args.handle_missing_data,
                              normalization="min_max", cache_task_stats=True)
    logger.info(reader)
    logger.info("== Shapes ==")
    logger.info(pformat(reader.data_shape))

    if args.mode == "read_one_batch":
        _log_random_item(reader)
        _log_random_batch(reader, args.batch_size)
    elif args.mode == "iterate_all_data":
        _iterate_all_data(reader, args.batch_size, args.num_workers)

# Script entry point: parse the command line, then hand the result to main().
if __name__ == "__main__":
    cli_args = get_args()
    main(cli_args)
