#!python
"""Run the experiments described in a multi-document YAML config file.

Example usage::

    dave --config-file="experiment.yaml" --starting-index=0 --
"""
import logging
from pathlib import Path
from subprocess import check_call

from params_proto import cli_parse, ParamsProto, is_hidden
from pathos.multiprocessing import ProcessingPool
from ruamel.yaml import YAML
from waterbear import DefaultBear, Bear


@cli_parse
class Experiment(ParamsProto):
    """Command-line options for this launcher: which YAML config file to load and where to start."""
    # NOTE(review): the string annotations below presumably serve as the
    # per-option help text for params_proto, so they are part of the CLI.
    # NOTE(review): `starting_index` is not read anywhere in the visible code
    # (`main` only uses `config_file`) — TODO confirm whether it should select
    # the first YAML document to run.
    config_file: "configuration of the experiment" = './experiment.yml'
    starting_index: "hashed or integer index for the starting experiment" = 0


class RunnerConfig(object):
    """Declares the settings accepted in the ``config`` section of a run config."""

    # Number of runs of one run config that may execute at the same time.
    max_concurrent: int = 4


class RunConfig(Bear):
    # Schema and defaults for one YAML document of the config file.
    # NOTE: `main` copies these class attributes (via `vars(RunConfig)`) into
    # every run config as defaults, so renaming, reordering or changing them
    # changes runtime behaviour.
    # NOTE(review): the `{}` / `[]` defaults are single objects shared by every
    # run config built from them; nothing visible mutates them, but an in-place
    # edit would leak between runs.

    # Runner settings such as `max_concurrent`; the class-level value is an
    # empty `DefaultBear(None)`.
    config: RunnerConfig = DefaultBear(None)
    # Environment variables; `run` serializes them as `KEY=value` pairs and
    # substitutes them for `{env}` in the command template.
    env: dict = {}
    # Shell command template; `run` fills in `{args}` (and `{env}`).
    run: str = 'python main.py {args}'
    # NOTE(review): not read anywhere in the visible code — presumably meant to
    # be merged into `args`; TODO confirm.
    default_args: dict = {}
    # Flags for a single run, serialized by `run` as `--key value`.
    args: dict = {}
    # List of `args` mappings; `job` launches one run per entry.
    batch_args: list = []


def env_serializer(env):
    """Render a mapping as space-separated ``KEY=value`` pairs, e.g. ``A=1 B=2``."""
    return " ".join(f"{name}={value}" for name, value in env.items())


def args_serializer(kwargs):
    """Render keyword arguments as command-line flags.

    Each ``key: value`` item becomes ``--key value``, with underscores in the key
    replaced by dashes: ``{'learning_rate': 0.1}`` -> ``--learning-rate 0.1``.
    """
    flags = [f"--{key.replace('_', '-')} {value}" for key, value in kwargs.items()]
    return " ".join(flags)


def run(run_config: RunConfig):
    """Render one run config into a shell command and execute it.

    The command is ``run_config.run`` formatted with two placeholders:
    ``{args}`` -- ``run_config.args`` as ``--key value`` flags (underscores in
    keys become dashes) -- and ``{env}`` -- ``run_config.env`` as ``KEY=value``
    pairs. The configured environment variables only reach the process if the
    template contains ``{env}``; otherwise the command inherits this process's
    environment unchanged.

    A failing command is logged rather than raised, so that one failed run does
    not abort the remaining runs of the batch.
    """
    # A YAML key with an empty value loads as None; treat it as "no entries".
    env_hydrated = Bear(**(run_config.env or {}))
    env_serialized = env_serializer(env_hydrated)
    logging.info(env_serialized)
    args_serialized = args_serializer(run_config.args or {})
    logging.info(args_serialized)
    # NOTE(review): values are interpolated unquoted into a shell string, so a
    # value containing spaces or shell metacharacters is split/expanded by the
    # shell. Acceptable only because the config file is trusted input.
    script = run_config.run.format(args=args_serialized, env=env_serialized)
    logging.info(script)
    try:
        check_call(script, shell=True)
    except Exception as e:  # best-effort boundary: keep the other runs going
        logging.error("Run failed: %s\n  command: %s", e, script)


def _args_as_dict(args):
    """Return ``args`` as a plain dict; ``None`` (an empty YAML value) becomes ``{}``.

    Accepts plain dicts as well as objects that support ``vars()`` (e.g. Bear).
    """
    if args is None:
        return {}
    if isinstance(args, dict):
        return dict(args)
    return vars(args)


def _resolve_max_concurrent(runner_config):
    """Return the pool size from a ``config`` section.

    Falls back to ``RunnerConfig.max_concurrent`` when the section is missing
    or null, or when it does not set ``max_concurrent``.
    """
    if runner_config is None:
        value = None
    elif isinstance(runner_config, dict):
        value = runner_config.get('max_concurrent')
    else:
        try:
            value = runner_config.max_concurrent
        except (AttributeError, KeyError):
            value = None
    return RunnerConfig.max_concurrent if value is None else value


def job(run_config: RunConfig):
    """Launch every run described by one run config on a process pool.

    The runs are taken from ``batch_args`` (one run per entry) when it is
    non-empty; otherwise a single run is made from ``args``. Each run gets the
    remaining fields of ``run_config`` with ``args`` replaced by that run's
    arguments. The runs are mapped over a ``ProcessingPool`` sized by
    ``config.max_concurrent``, or ``RunnerConfig.max_concurrent`` when unset.

    Raises:
        RuntimeError: if ``run_config`` contains neither ``batch_args`` nor ``args``.
    """
    # todo: Pool.map does not pickle this inner function correctly. Need to change strategy.
    logging.info(vars(run_config))

    # `main` pre-populates every RunConfig field, including `batch_args=[]`, so
    # a membership test alone is always true. An empty `batch_args` must fall
    # through to `args`; otherwise configs that only set `args` would run nothing.
    has_batch_args = 'batch_args' in run_config
    has_args = 'args' in run_config
    if has_batch_args and run_config.batch_args:
        if has_args and run_config.args:
            logging.warning("both batch_args and args are defined. Only batch_args are used")
        batch_args = run_config.batch_args
    elif has_args:
        batch_args = [_args_as_dict(run_config.args)]
    elif has_batch_args:
        # An explicitly empty `batch_args` and no `args`: nothing to launch.
        batch_args = []
    else:
        raise RuntimeError("Neither `batch_args` nor `args` is found in run_config, please check your config file.")

    if not batch_args:
        logging.warning("No runs to launch for this run config.")
        return

    rest_run_config = {k: v for k, v in vars(run_config).items() if k not in ('args', 'batch_args')}
    # NOTE(review): `default_args` is carried along unchanged and never merged
    # into `args` here — TODO confirm whether that is intended.
    batch_run_configs = [Bear(**rest_run_config, args=args) for args in batch_args]

    runner_config = run_config.config if 'config' in run_config else None
    pool = ProcessingPool(_resolve_max_concurrent(runner_config))
    pool.map(run, batch_run_configs)


def main():
    """Load every YAML document in ``Experiment.config_file`` and run each one.

    Each document is layered on top of the class-level fields of ``RunConfig``
    (those for which ``is_hidden`` is false) and handed to ``job``.
    """
    # NOTE(review): typ='unsafe' lets the YAML file construct arbitrary Python
    # objects — only use it with trusted config files.
    loader = YAML(typ='unsafe', pure=True)
    documents = loader.load_all(Path(Experiment.config_file))

    # NOTE(review): Experiment.starting_index is not consulted here.
    defaults = {
        field_name: default
        for field_name, default in vars(RunConfig).items()
        if not is_hidden(field_name)
    }
    for document in documents:
        hydrated: RunConfig = DefaultBear(None, **defaults)
        hydrated.update(**document)
        job(hydrated)


if __name__ == "__main__":
    # Run only when executed as a script, not when imported.
    main()
