#!python
"""Example script
dave --config-file="experiment.yaml" --starting-index=0 --dave-log=info
"""
import logging
from copy import deepcopy
from pathlib import Path
from subprocess import check_call

from dave import RunTimeParams, hydrate_templates
# noinspection PyUnresolvedReferences
from params_proto import is_hidden, cli_parse, ParamsProto
from pathos.multiprocessing import ProcessingPool
from ruamel.yaml import YAML, sys
from waterbear import DefaultBear, Bear

# The ``Experiment`` parameter class is declared from source text and exec'd.
# The 3.6+ variant uses variable annotations, which would be a SyntaxError on
# older interpreters if written literally in this module.
# NOTE(review): the two variants disagree on the default ``starting_index``
# (0 vs None) -- confirm whether that is intentional.
if sys.version_info < (3, 6):
    # Pre-3.6 form: each field is written as ``name = default, """help"""``.
    script = (
        '@cli_parse\n'
        'class Experiment(ParamsProto):\n'
        '    """Supervised MAML in tensorflow"""\n'
        '    config_file = "./experiment.yml", """configuration of the experiment"""\n'
        '    starting_index = None, """hashed or integer index for the starting experiment"""\n'
        '    dave_log = "info", """log level for logging"""\n'
    )
else:
    # 3.6+ form: each field is written as ``name: "help" = default``.
    script = (
        '@cli_parse\n'
        'class Experiment(ParamsProto):\n'
        '    """Supervised MAML in tensorflow"""\n'
        '    config_file: "configuration of the experiment" = "./experiment.yml"\n'
        '    starting_index: "hashed or integer index for the starting experiment" = 0\n'
        '    dave_log: "log level for logging" = "info"\n'
    )
# Executing the source binds ``Experiment`` in this module's globals.
exec(script, globals())


class RunnerConfig:
    """Runner-level settings; the declared type of ``RunConfig.config``."""
    # ``job`` passes ``run_config.config.max_concurrent`` to ``ProcessingPool``
    # (presumably the number of worker processes).
    # NOTE(review): nothing visible in this file copies this default of 4 into
    # a run's ``config`` -- confirm where it is applied.
    max_concurrent = 4  # type: int


class RunConfig(Bear):
    # Default values for one run document. ``main`` copies the non-hidden
    # attributes of this class into a ``DefaultBear`` and overlays each YAML
    # document on top, so every name below is also a recognised config key.

    # Runner settings; ``job`` reads ``config.max_concurrent`` from it.
    config = DefaultBear(None)  # type: RunnerConfig
    # Environment variables, rendered as ``KEY=VALUE`` pairs for the ``{env}``
    # placeholder of ``run``.
    env = {}  # type: dict
    # Command template; ``run()`` formats it with ``args`` and ``env`` and
    # executes the result through the shell.
    run = 'python main.py {args}'  # type: str
    # Base arguments; ``job`` deep-copies them for every run before applying
    # that run's own arguments on top.
    default_args = {}  # type: dict
    # Arguments for a single run (used by ``job`` when ``batch_args`` is not
    # on the config).
    args = {}  # type: dict
    # One argument mapping per run; ``job`` launches one command per entry.
    batch_args = []  # type: list


def env_serializer(env):
    """Render a mapping as space-separated ``KEY=VALUE`` pairs."""
    pairs = []
    for name, value in env.items():
        pairs.append("{}={}".format(name, value))
    return " ".join(pairs)


def args_serializer(kwargs):
    """Render a mapping as space-separated ``--key value`` options.

    Underscores in each key are replaced with dashes; values are rendered
    with ``str.format`` and are not quoted.
    """
    options = []
    for key, value in kwargs.items():
        flag = key.replace('_', '-')
        options.append("--{} {}".format(flag, value))
    return " ".join(options)


def run(run_config: RunConfig):
    """Build the shell command for one run and execute it.

    The command is ``run_config.run`` formatted with the serialized ``args``
    and ``env``. A failing command is printed rather than raised, so this
    function returns normally whether or not the command succeeded.

    :param run_config: configuration providing ``env``, ``args`` and the
        ``run`` command template.
    """
    env_serialized = env_serializer(run_config.env)
    # The return value is unused and ``args`` is serialized right after, so
    # this presumably resolves template values in ``args`` in place.
    hydrate_templates(run_config.args, RunTimeParams())
    args_serialized = args_serializer(run_config.args)
    script = run_config.run.format(args=args_serialized, env=env_serialized)

    # Log before launching: the exact command is then visible while the job
    # is still running, not only once it has finished.
    logging.debug("dave:env_serialized:{}".format(env_serialized))
    logging.debug("dave:args_serialized:{}".format(args_serialized))
    logging.debug("dave:job_script:{}".format(script))

    try:
        # NOTE(security): the command goes through the shell with unquoted
        # values taken from the config file; only run trusted configs.
        check_call(script, shell=True)
    except Exception as e:
        # Best effort: report the failure instead of propagating it.
        print(e)


def job(run_config: RunConfig):
    """Expand one run configuration into individual runs and execute them.

    When ``batch_args`` is on ``run_config`` it supplies one argument mapping
    per run; otherwise the single ``args`` entry is used. Every mapping is
    applied on top of a deep copy of ``default_args``, and the resulting
    configs are handed to ``run`` through a ``ProcessingPool``.

    :param run_config: the hydrated configuration for this job.
    :raises RuntimeError: if neither ``batch_args`` nor ``args`` is on
        ``run_config``.
    """
    has_batch = 'batch_args' in run_config
    if not has_batch and 'args' not in run_config:
        raise RuntimeError("Neither `batch_args` nor `args` is found in run_config, please check your config file.")

    if has_batch:
        # NOTE(review): ``main`` seeds every config with the ``RunConfig``
        # defaults, which include an empty ``batch_args``. If the membership
        # test sees that default, a document that only sets ``args`` lands
        # here and launches nothing -- confirm.
        if ('args' in run_config) and run_config.args:
            logging.warning("both batch_args and args are defined. Only batch_args are used")
        arg_sets = run_config.batch_args  # type: list
    else:
        arg_sets = [vars(run_config.args)]  # type: list

    def _with_defaults(overrides):
        # Fresh copy per run, so runs never share (or mutate) one config.
        merged = deepcopy(run_config)
        merged.args = deepcopy(run_config.default_args)
        merged.args.update(**overrides)
        return merged

    configs = [_with_defaults(overrides) for overrides in arg_sets]
    pool = ProcessingPool(run_config.config.max_concurrent)
    pool.map(run, configs)


def main():
    """Load every YAML document in ``Experiment.config_file`` and run it.

    Each document is overlaid on the non-hidden class attributes of
    ``RunConfig`` and passed to ``job``. Documents are processed one after
    another, in file order. Empty documents (for example the one produced by
    a trailing ``---``) are skipped with a warning.
    """
    # NOTE(security): typ='unsafe' allows the YAML to construct arbitrary
    # Python objects; only load config files you trust.
    yaml = YAML(typ='unsafe', pure=True)

    # Keyword expansion below builds a fresh mapping for each document, so
    # computing the defaults once is enough.
    defaults = {k: v for k, v in vars(RunConfig).items() if not is_hidden(k)}

    # noinspection PyUnresolvedReferences
    parsed = yaml.load_all(Path(Experiment.config_file))
    for run_config in parsed:
        if run_config is None:
            # An empty document loads as None and cannot be unpacked with
            # ``**``; there is nothing to run for it.
            logging.warning("dave: skipping empty document in config file")
            continue
        hydrated = DefaultBear(None, **defaults)  # type: RunConfig
        hydrated.update(**run_config)
        job(hydrated)


if __name__ == "__main__":
    # Map the configured level name (e.g. "info", "debug") onto the numeric
    # constant exported by the logging module.
    log_level = getattr(logging, Experiment.dave_log.upper(), None)
    if not isinstance(log_level, int):
        # ``setLevel`` rejects None (and non-level attributes of ``logging``)
        # with a TypeError, so fall back to INFO and say so.
        logging.warning("dave: unknown log level {!r}, falling back to 'info'".format(Experiment.dave_log))
        log_level = logging.INFO
    logging.getLogger().setLevel(log_level)
    main()
