#!python
"""Run each document of a multi-document YAML config as a batch of shell commands.

Example::

    dave --config-file="experiment.yaml" --starting-index=0 --
"""
import logging
from pathlib import Path
from subprocess import check_call

# noinspection PyUnresolvedReferences
from params_proto import is_hidden, cli_parse, ParamsProto
from pathos.multiprocessing import ProcessingPool
from ruamel.yaml import YAML, sys
from waterbear import DefaultBear, Bear

# ``Experiment`` holds the command-line parameters (parsed by params_proto's
# ``@cli_parse``). Its source is assembled as a string and exec'd into this
# module's globals because the 3.6+ variant uses variable annotations, which
# would be a SyntaxError if written literally and parsed by an older
# interpreter.
# NOTE(review): the two variants disagree on the ``starting_index`` default
# (0 vs None), and the class docstring looks copied from another project.
_experiment_header = (
    '@cli_parse\n'
    'class Experiment(ParamsProto):\n'
    '    """Supervised MAML in tensorflow"""\n'
)
if sys.version_info < (3, 6):
    # Pre-3.6 spelling: ``name = default, """description"""``.
    _experiment_fields = (
        '    config_file = "./experiment.yml", """configuration of the experiment"""\n'
        '    starting_index = None, """hashed or integer index for the starting experiment"""\n'
    )
else:
    # 3.6+ spelling: ``name: "description" = default``.
    _experiment_fields = (
        '    config_file: "configuration of the experiment" = "./experiment.yml"\n'
        '    starting_index: "hashed or integer index for the starting experiment" = 0\n'
    )
script = _experiment_header + _experiment_fields
exec(script, globals())


class RunnerConfig:
    """Describes the ``config`` section of a run configuration.

    Used as the ``# type:`` hint of ``RunConfig.config``.
    NOTE(review): nothing visible here copies ``max_concurrent = 4`` into a
    hydrated config, so this value may only be documentary — confirm.
    """

    # number of worker processes requested for the job's process pool
    max_concurrent = 4  # type: int


class RunConfig(Bear):
    # Field defaults for one run configuration. ``main`` copies the
    # non-hidden class attributes below into a DefaultBear and overlays each
    # YAML document on top of them.
    # NOTE(review): the mutable defaults ({} / []) are single shared objects;
    # the visible code never mutates them, but anything that does would leak
    # the change into every later document.

    # runner settings; ``job`` reads ``config.max_concurrent``
    config = DefaultBear(None)  # type: RunnerConfig
    # KEY -> value pairs, serialized as ``KEY=value`` for the ``{env}`` slot
    env = {}  # type: dict
    # str.format template of the shell command (``{args}`` / ``{env}``)
    run = 'python main.py {args}'  # type: str
    # NOTE(review): not read by any code visible in this file
    default_args = {}  # type: dict
    # option -> value mapping for a single run
    args = {}  # type: dict
    # list of option mappings, one per run in the batch
    batch_args = []  # type: list


def env_serializer(env):
    """Render a mapping as space-separated ``KEY=value`` assignments.

    Values are inserted verbatim via ``str.format`` (no shell quoting).

    >>> env_serializer({"A": 1, "B": "x"})
    'A=1 B=x'
    """
    assignments = ["{}={}".format(name, value) for name, value in env.items()]
    return " ".join(assignments)


def args_serializer(kwargs):
    """Render a mapping as ``--option value`` flags, with ``_`` turned into ``-``.

    Values are inserted verbatim via ``str.format`` (no shell quoting).

    >>> args_serializer({"learning_rate": 0.1})
    '--learning-rate 0.1'
    """
    flags = []
    for key, value in kwargs.items():
        option = key.replace('_', '-')
        flags.append("--{} {}".format(option, value))
    return " ".join(flags)


def run(run_config: RunConfig):
    """Execute a single run configuration as a shell command.

    ``run_config.run`` is a ``str.format`` template: ``{args}`` is replaced by
    the serialized ``run_config.args`` (``--opt value ...``) and ``{env}`` by
    the serialized ``run_config.env`` (``KEY=value ...``). Any literal braces
    in the template (e.g. shell ``${VAR}``) must be doubled (``${{VAR}}``).
    The resulting string is executed with ``shell=True``.

    NOTE: ``env`` is only substituted into the template; it is not passed as
    the subprocess environment, so it has no effect unless the template
    contains ``{env}``.

    A failing command (non-zero exit, or the shell cannot be started) is
    logged and not re-raised, so one bad run does not raise out of ``run``.

    :param run_config: hydrated run configuration providing ``run``,
        ``args`` and ``env``.
    """
    from subprocess import CalledProcessError

    env_hydrated = Bear(**run_config.env)
    env_serialized = env_serializer(env_hydrated)
    logging.info(env_serialized)
    args_serialized = args_serializer(run_config.args)
    logging.info(args_serialized)
    script = run_config.run.format(args=args_serialized, env=env_serialized)
    logging.info(script)
    try:
        check_call(script, shell=True)
    except (CalledProcessError, OSError) as e:
        # Best effort: report through logging (like the rest of this module)
        # instead of printing, and only for errors check_call can raise.
        logging.error("command failed (%s): %s", e, script)


def job(run_config: RunConfig):
    """Expand one run configuration into individual runs and execute them in parallel.

    Every element of ``run_config.batch_args`` becomes the ``args`` of one
    run. When ``batch_args`` is empty or missing, ``run_config.args`` is used
    for a single run instead. All other fields are copied unchanged into each
    run, and the runs are passed to ``run`` through a pathos
    ``ProcessingPool`` sized by ``run_config.config.max_concurrent``.

    :param run_config: hydrated run configuration (as built by ``main``).
    :raises RuntimeError: if neither ``batch_args`` nor ``args`` is present.
    """
    logging.info(vars(run_config))

    has_batch_args = 'batch_args' in run_config
    has_args = 'args' in run_config
    batch_args = run_config.batch_args if has_batch_args else None

    # ``main`` fills in RunConfig's defaults (``batch_args=[]``, ``args={}``),
    # so both keys are normally *present*. Decide on emptiness instead;
    # otherwise a config that only sets ``args`` would silently run nothing.
    if batch_args:
        if has_args and run_config.args:
            logging.warning("both batch_args and args are defined. Only batch_args are used")
    elif has_args:
        single_args = run_config.args
        if single_args is None:
            # e.g. ``args:`` left empty in the YAML document
            single_args = {}
        elif isinstance(single_args, dict):
            # copy so runs never alias RunConfig's shared default dict
            single_args = dict(single_args)
        else:
            # attribute-style wrapper (e.g. a Bear) around the mapping
            single_args = vars(single_args)
        batch_args = [single_args]  # type: list
    elif has_batch_args:
        logging.warning("batch_args is empty and no args are given; nothing to run")
        return
    else:
        raise RuntimeError("Neither `batch_args` nor `args` is found in run_config, please check your config file.")

    rest_run_config = {k: v for k, v in vars(run_config).items() if k not in ('args', 'batch_args')}
    batch_run_configs = [Bear(**rest_run_config, args=args) for args in batch_args]

    # NOTE(review): without an explicit ``config.max_concurrent`` this is
    # whatever the DefaultBear yields for a missing key (presumably None,
    # i.e. pathos' default pool size), not RunnerConfig's 4 — confirm.
    pool = ProcessingPool(run_config.config.max_concurrent)
    pool.map(run, batch_run_configs)


def main():
    """Load every document of the configured YAML file and run each as a job.

    The file path comes from ``Experiment.config_file``. Each YAML document is
    overlaid on RunConfig's non-hidden class attributes (the defaults) and
    handed to ``job``. Empty documents (which load as ``None``, e.g. from a
    trailing ``---``) are skipped with a warning.
    """
    # NOTE(review): typ='unsafe' lets the YAML file construct arbitrary Python
    # objects; only use it with trusted config files.
    yaml = YAML(typ='unsafe', pure=True)

    # ``Experiment`` is created at import time by the exec'd @cli_parse class.
    # noinspection PyUnresolvedReferences
    config_path = Path(Experiment.config_file)
    # NOTE(review): Experiment.starting_index is parsed but not used here.

    defaults = {k: v for k, v in vars(RunConfig).items() if not is_hidden(k)}
    for index, run_config in enumerate(yaml.load_all(config_path)):
        if run_config is None:
            logging.warning("skipping empty YAML document #%d in %s", index, config_path)
            continue
        hydrated = DefaultBear(None, **defaults)  # type: RunConfig
        hydrated.update(**run_config)
        job(hydrated)


if __name__ == "__main__":
    # Script entry point: parse nothing extra here, defer to main().
    main()
