#!/usr/bin/env python3
from argparse import ArgumentParser
from typing import List
from pathlib import Path
from datetime import datetime
import os
from natsort import natsorted
import numpy as np
import pandas as pd
import pims
from vre import VRE
from vre.utils import FakeVideo, image_read, get_project_root
from vre.representations import build_representations_from_cfg
from vre.logger import logger
from omegaconf import OmegaConf

def _print_and_store_run_stats(run_stats: pd.DataFrame) -> None:
    """Print the run statistics and store them as a timestamped CSV file in the logs directory.

    The logs directory is read from the VRE_LOGS_DIR environment variable; if the variable is not set,
    '<project root>/logs' is used. The directory is created if it does not exist.

    Parameters:
    - run_stats: The data frame of run statistics, as returned by the VRE call in main().
    """
    # A frame with at most one row is transposed before printing, so its columns are listed vertically.
    print(run_stats if len(run_stats) > 1 else run_stats.T)

    # Resolve the project root only when the environment variable is not set.
    env_logs_dir = os.getenv("VRE_LOGS_DIR")
    logs_dir = Path(env_logs_dir) if env_logs_dir is not None else Path(get_project_root() / "logs")
    logs_dir.mkdir(exist_ok=True, parents=True)

    # NOTE(review): the ISO timestamp contains ':' characters, which are not valid in Windows file names.
    timestamp = datetime.now().replace(microsecond=0).isoformat()
    run_stats_out_file = logs_dir / f"run_stats_{timestamp}.csv"
    run_stats.to_csv(run_stats_out_file)
    # Log only after the write succeeded, so the message never reports a file that was not written.
    logger.info(f"Stored vre run log file at '{run_stats_out_file}'")

def get_args(argv=None):
    """Parse and normalize the cli args.

    Parameters:
    - argv: Optional list of argument strings to parse. If None (default), argparse reads sys.argv[1:].

    Returns the argparse namespace where:
    - video_path, output_path and cfg_path are absolute pathlib.Path objects
    - output_size is a single string if one value was given (default: "video_shape") or a tuple of two ints
      if two values were given.

    Invalid --output_size values (more than two values, or two values that are not integers) are reported
    through parser.error(), which prints the usage message and exits with status 2.
    """
    parser = ArgumentParser()
    parser.add_argument("video_path", help="Path to the scene video we are processing")
    parser.add_argument("--output_path", "-o", required=True,
                        help="Path to the output directory where representations are stored")
    parser.add_argument("--cfg_path", required=True, help="Path to global YAML cfg file")
    parser.add_argument("--start_frame", type=int, default=0, help="The first frame. If not set, defaults to first.")
    parser.add_argument("--end_frame", type=int, help="End frame. If not set, defaults to length of video")
    parser.add_argument("--output_dir_exist_mode", choices=["overwrite", "skip_computed", "raise"])
    parser.add_argument("--frame_rate", type=int, help="For a directory of images. Required for optical flow")
    parser.add_argument("--batch_size", type=int)
    parser.add_argument("--output_size", nargs="+", default=["video_shape"])
    parser.add_argument("--representations", nargs="+", help="Subset of representations to run VRE for")
    parser.add_argument("--n_threads_data_storer", type=int, default=0)

    args = parser.parse_args(argv)
    args.video_path = Path(args.video_path).absolute()
    args.output_path = Path(args.output_path).absolute()
    # --cfg_path is required, so it is always a string at this point.
    args.cfg_path = Path(args.cfg_path).absolute()

    # nargs="+" guarantees at least one value, so only the upper bound must be validated. An explicit check is
    # used instead of an assert because asserts are stripped when python runs with -O.
    if len(args.output_size) > 2:
        parser.error(f"--output_size expects one or two values, got {args.output_size}")
    if len(args.output_size) == 1:
        args.output_size = args.output_size[0]
    else:
        try:
            args.output_size = tuple(map(int, args.output_size))
        except ValueError:
            parser.error(f"--output_size expects two integers when two values are given, got {args.output_size}")
    return args

def get_cfg():
    """Parse the cli args, load the YAML cfg and override the cfg values with the cli args.

    Returns a tuple (args, cfg) where:
    - args is the namespace returned by get_args()
    - cfg is the config loaded from args.cfg_path, with cfg.vre updated by the cli args that are not None and
      cfg.representations optionally reduced to the names given via --representations.

    Raises AssertionError if --representations contains a name that is not a key of cfg.representations.
    """
    args = get_args()
    # Pass the path itself: OmegaConf.load opens and closes the file, so no file handle is left open.
    cfg = OmegaConf.load(args.cfg_path)

    def _update_cfg_from_args(cfg, args, items: List[str]):
        """Copy each attribute named in items from args into cfg, skipping those that are None."""
        for item in items:
            value = getattr(args, item)
            if value is None:
                continue
            cfg[item] = value
            logger.warning(f"--{item} is set to '{cfg[item]}'. Will update the cfg value")

        return cfg

    cfg.vre = _update_cfg_from_args(cfg.vre, args, ["start_frame", "end_frame", "output_dir_exist_mode",
                                                    "batch_size", "output_size", "n_threads_data_storer"])
    if args.representations is not None:
        available = cfg.representations.keys()
        unknown = set(args.representations).difference(available)
        assert unknown == set(), f"{unknown} not in {available}"
        logger.info(f"Keeping only {args.representations} from the config")
        cfg.representations = {k: v for k, v in cfg.representations.items() if k in args.representations}
    return args, cfg

def main():
    """Build the video and the representations from the cli args and cfg, run VRE and store the run stats."""
    args, cfg = get_cfg()

    if not args.video_path.is_dir():
        # A regular file is opened as a video.
        video = pims.Video(args.video_path)
    else:
        # A directory is treated as a sequence of PNG frames, naturally sorted by file name.
        image_paths = natsorted(args.video_path.glob("*.png"), key=lambda path: path.name)
        frames = [image_read(image_path) for image_path in image_paths]
        same_shape = all(frame.shape == frames[0].shape for frame in frames)
        assert same_shape, f"Images shape differ in '{args.video_path}'"
        logger.info(f"--video_path is a directory. Assuming a directory of images. Found {len(frames)} images.")
        stacked_frames = np.array(frames, dtype=np.uint8)
        video = FakeVideo(stacked_frames, frame_rate=args.frame_rate)

    vre = VRE(video, build_representations_from_cfg(cfg.representations))
    # cfg.vre holds the keyword arguments of the run (after the cli overrides applied in get_cfg).
    stats = vre(args.output_path, **cfg.vre)
    _print_and_store_run_stats(stats)

# Run the cli only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
