#!/usr/bin/env python

import wispy.templates
import wispy.utils
import wispy.logger
import os
import argparse
import copy
from tomlkit import dumps
from string import Template
import shutil


def update_nested(dict_a, dict_b):
    """Recursively merge ``dict_b`` into ``dict_a`` and return ``dict_a``.

    Adapted from https://stackoverflow.com/a/57969936

    ``dict_a`` is modified in place. For every key ``k`` in ``dict_b``:

    * if both ``dict_a[k]`` and ``dict_b[k]`` are dicts, they are merged
      recursively with the same rules;
    * otherwise ``dict_a[k]`` is set to ``dict_b[k]``. The key is added if it
      was missing, and a nested table is replaced if the override is a scalar.

    Keys present only in ``dict_a`` are left untouched. Falsy override values
    such as ``False``, ``0`` or ``""`` are applied like any other value, so a
    boolean option can be switched off by the override.

    New keys are inserted in ``dict_b``'s iteration order, so serialising the
    result gives a deterministic key order.

    Args:
        dict_a (dict): base mapping; updated in place.
        dict_b (dict): mapping whose values take precedence.

    Returns:
        dict: ``dict_a`` after the update (the same object that was passed in).
    """
    for key, new_value in dict_b.items():
        old_value = dict_a.get(key)
        if isinstance(old_value, dict) and isinstance(new_value, dict):
            # Merge tables key-by-key instead of replacing them wholesale, so
            # template entries not mentioned in the override survive.
            update_nested(old_value, new_value)
        else:
            dict_a[key] = new_value
    return dict_a


def _build_arg_parser():
    """Return the command-line argument parser for this script."""
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description="""This script is used to modify a template fit.toml config file
and generate bash run scripts to run the `wispy_autoencoder_fit`
program.

You can use this to do things like try out different optimisers or
batch sizes etc

$ bin/wispy_make_explorer_workflow --config-file explorer.toml --template-config-file wispy_fit_template.toml -v
""",
    )

    # Both config files are mandatory: without them the script fails deep
    # inside the TOML reader with an unhelpful error.
    parser.add_argument(
        "--config-file",
        type=str,
        required=True,
        help="path to explorer config.toml file",
    )

    parser.add_argument(
        "--template-config-file",
        type=str,
        required=True,
        help="path to fit template config.toml file",
    )

    parser.add_argument(
        "--output-dir", type=str, help="output directory", default="runs"
    )

    parser.add_argument(
        "-v",
        help="""increase output verbosity
        no -v: WARNING
        -v: INFO
        -vv: DEBUG""",
        action="count",
        dest="verbose",
        default=0,
    )
    parser.add_argument(
        "--force",
        action="store_true",
        help="if given then will not throw an error if output directory exists",
    )
    return parser


def _write_run_dir(subdir, run_doc, run_script, force, logger):
    """Create ``subdir`` and write its ``config.toml`` and executable ``run.sh``.

    Args:
        subdir (str): directory for this run; created if missing.
        run_doc: merged TOML document for this run, serialised with ``dumps``.
        run_script (str): contents of the per-run ``run.sh``.
        force (bool): passed as ``exist_ok`` to ``os.makedirs``.
        logger: logger used to report the config file being written.
    """
    os.makedirs(subdir, exist_ok=force)

    config_path = os.path.join(subdir, "config.toml")
    logger.info(f"writing: {config_path}")
    wispy.utils.write_string_to_file(dumps(run_doc), config_path)

    script_path = os.path.join(subdir, "run.sh")
    wispy.utils.write_string_to_file(run_script, script_path)
    os.chmod(script_path, 0o775)


def main():
    """Generate one run directory per ``[[run]]`` entry plus a ``run_all.sh``.

    Each ``run`` table in the explorer config is merged over a fresh copy of
    the template config. The result is written to
    ``<output-dir>/run_XXX/config.toml`` together with a ``run.sh`` that
    invokes the fit program. ``<output-dir>/run_all.sh`` launches every
    ``run.sh`` through the nohup template.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    # https://stackoverflow.com/questions/14097061/easier-way-to-enable-verbose-logging
    level = min(2, args.verbose)  # capped to number of levels
    logger = wispy.logger.init_logger(level=level)

    logger.info(f"wispy version: {wispy.__version__}")

    logger.info("==========")
    logger.info("printing command line args")
    for key, value in vars(args).items():
        logger.info(f"{key}: {value}")
    logger.info("==========")

    doc = wispy.utils.read_and_parse_toml(args.config_file)
    template_doc = wispy.utils.read_and_parse_toml(args.template_config_file)

    logger.info("==========")
    logger.info("printing toml config contents")
    wispy.utils.recursive_dict_print(doc, print_fn=logger.info)
    logger.info("==========")

    runs = doc.get("run")
    if not runs:
        # parser.error prints usage and exits with status 2.
        parser.error(f"no [[run]] entries found in {args.config_file}")

    number_of_runs = len(runs)
    logger.info(f"number of runs: {number_of_runs}")

    # Build the merged documents in a separate list rather than writing them
    # back into the parsed explorer doc.
    run_docs = []
    for i, run in enumerate(runs):
        logger.info(f"performing nested update: {i}")
        # Deep copy is important: update_nested mutates its first argument,
        # so every run must start from a pristine copy of the template.
        run_docs.append(update_nested(copy.deepcopy(template_doc), run))

    logger.info(f"making output dir: {args.output_dir}")
    try:
        os.makedirs(args.output_dir, exist_ok=args.force)
    except FileExistsError:
        parser.error(
            f"output directory {args.output_dir} already exists; "
            "pass --force to reuse it"
        )

    logger.info("generating run scripts and writing config files to run dirs")
    # The per-run script and nohup snippet are identical for every run (each
    # run.sh reads the config.toml in its own directory), so build them once.
    fit_str = Template(wispy.templates.template_wispy_autoencoder_fit).substitute(
        config="config.toml"
    )
    run_script = wispy.templates.template_bash_shbang + "\n\n" + fit_str
    nohup_str = Template(wispy.templates.template_nohup).substitute(
        name="run.sh", logfile="run.log"
    )

    # The `pushd run_XXX` lines are relative to the output directory, so
    # run_all.sh first changes into its own directory. This lets it work no
    # matter where it is launched from.
    run_all_parts = [
        wispy.templates.template_bash_shbang,
        '\ncd "$(dirname "$0")" || exit 1\n',
    ]
    for i, run_doc in enumerate(run_docs):
        rundir = f"run_{i:03}"
        _write_run_dir(
            os.path.join(args.output_dir, rundir),
            run_doc,
            run_script,
            args.force,
            logger,
        )
        run_all_parts.append(f"\npushd {rundir}\n")
        run_all_parts.append(nohup_str)
        run_all_parts.append("\npopd\n")

    outname = os.path.join(args.output_dir, "run_all.sh")
    wispy.utils.write_string_to_file("".join(run_all_parts), outname)
    os.chmod(outname, 0o775)

    logger.info("finished!")


if __name__ == "__main__":
    main()