#!python
from pathlib import Path
import argparse
import nbclient
import nbformat
import nbparameterise
import sys
import os
import time
from collections import defaultdict

# Stage-one parser: it only recognises the notebook path and the runner's own
# switches.  Help is disabled here so that -h/--help falls through to the
# notebook-specific parser that is built later in the script.
parser = argparse.ArgumentParser(add_help=False, description="Notebook Runner")
parser.add_argument("Notebook")
parser.add_argument(
    "--use_notebook_client",
    action="store_true",
    default=False,
    help="Warning: Using this disables stdin",
)
parser.add_argument("--master_debug", action="store_true", default=False, help="")

# Anything this parser does not recognise is returned in ``unknown``.
nboargs, unknown = parser.parse_known_args(sys.argv[1:])

# Every debug switch mirrors --master_debug; only the inter-cell sleep is a
# separate knob (seconds, 0 disables it).
debugMode = nboargs.master_debug
debug_cell_sourcecode = debug_print_cell_numbers = debugMode
debug_sleep_between_cells = 0
debug_print_args = debug_print_params = debug_print_param_params = debugMode
debug_print_cwd = debugMode
if debug_print_cwd:
    print(os.getcwd())

# Store the notebook location as an absolute Path from here on.
nboargs.Notebook = Path(nboargs.Notebook).resolve()
if debug_print_args:
    print(nboargs)

# Run from the notebook's own directory and make that directory importable.
os.chdir(nboargs.Notebook.parent)
sys.path.insert(0, str(nboargs.Notebook.parent))
with nboargs.Notebook.open("rb") as f:
    nb = nbformat.read(f, as_version=4)

# Parameters nbparameterise extracts from the notebook.
notebook_params = nbparameterise.extract_parameters(nb)
if debug_print_params:
    print("orginal params", notebook_params)


# Stage-two parser: receives one argument per notebook parameter.
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
flagglessArgs = {}

# Required parameters get their own help section; optional ones are added
# straight to the parser.  ``gs`` selects the target by the "required" flag.
requiredNamed = parser.add_argument_group("Required notebook arguments")
optionalNamed = parser
gs = {True: requiredNamed, False: optionalNamed}

# Per-parameter bookkeeping, created on first access with these defaults.
p_prams = defaultdict(
    lambda: dict(type=None, required=False, toLower=False, flagglessArgs=False)
)


# Register one argparse argument per notebook parameter.
#
# A parameter's trailing comment is split on "#".  Segments are examined from
# the end: one ending in "required" or "tolower" (case-insensitive) sets the
# matching flag in ``p_prams``; the first other segment stops the scan and
# everything up to and including it becomes the help text.  The "toLower"
# flag is only recorded here; nothing in this loop consumes it.
for item in notebook_params:
    name = item.name
    p_prams[name]["type"] = item.type
    comment = ""
    value = item.value
    if item.comment is not None:
        commentlist = item.comment.split("#")
        while commentlist:
            comment = commentlist.pop()
            if comment.lower().endswith("required"):
                p_prams[name]["required"] = True
            elif comment.lower().endswith("tolower"):
                p_prams[name]["toLower"] = True
            else:
                # Put the non-flag segment back so the help text is not lost.
                # A blank segment (as produced by a "##" separator) is still
                # dropped, which keeps that spelling's output unchanged.
                if comment.strip():
                    commentlist.append(comment)
                comment = "#".join(commentlist)
                break

    # argparse %-formats help strings, so a literal "%" has to be doubled.
    comment = comment.replace("%", "%%")

    if item.name.startswith("__") and item.name.endswith("__"):
        # Dunder-wrapped parameters become positional ("flagless") arguments.
        # The original name is stored so it can be recovered after parsing.
        name = name.replace("__", "")
        p_prams[name]["flagglessArgs"] = item.name
        if item.type is list:
            if debug_print_params:
                print(item)
            parser.add_argument(name, choices=value, help=comment)
        else:
            parser.add_argument(
                name,
                nargs="?",
                default=value,
                type=item.type,
                help=f"{comment}  (default: %(default)s)",
            )

    else:
        required = p_prams[name]["required"]
        if required:
            # Required parameters use a single-dash flag; an empty-string
            # default is replaced by None.
            flag = f"-{name}"
            if value == "":
                value = None
        else:
            flag = f"--{name}"

        group = gs[required]
        if item.type is list:
            # A list-valued parameter supplies the set of allowed choices.
            group.add_argument(
                flag,
                choices=value,
                help=comment,
                required=required,
            )
        elif item.type is bool:
            # Booleans are entered as the literal strings "True"/"False".
            group.add_argument(
                flag,
                default=value,
                choices=["True", "False"],
                help=comment,
                required=required,
            )
        else:
            group.add_argument(
                flag,
                nargs="?",
                default=value,
                type=item.type,
                help=f"{comment}  (default: %(default)s)",
                required=required,
            )
if debug_print_param_params:
    print("p_prams", p_prams)

# ``args`` is sys.argv itself, not a copy: dropping the first two entries
# (this script and the notebook path) therefore also shortens sys.argv.
args = sys.argv
if debug_print_args:
    print(args)
del args[:2]
if debug_print_args:
    print(args)

# Parse what is left with the notebook-specific parser; leftovers are kept
# in ``unknown`` rather than treated as errors.
args, unknown = parser.parse_known_args(args)
if debug_print_args:
    print("NBARGS", args, unknown)


# Translate argparse destinations back into the notebook's parameter names.
final_args = {}
args = vars(args)
for a in args:
    # For positional ("flagless") arguments this entry holds the parameter's
    # original dunder-wrapped name; for everything else it is False.  Using
    # the stored name instead of re-wrapping ``a`` keeps names with inner
    # double underscores (e.g. ``__my__var__``) intact.
    final_args[p_prams[a]["flagglessArgs"] or a] = args[a]


# Overlay the command-line values on the notebook's declared parameters.
new_params = nbparameterise.parameter_values(notebook_params, **final_args)
if debug_print_params:
    print("replacement params", new_params)
for item in new_params:
    if p_prams[item.name]["type"] == bool:
        # A value typed on the command line is the string "True"/"False",
        # while an untouched default is already a bool.  Comparing the
        # string form handles both; comparing the raw value against "True"
        # would turn a default of True into False.
        item.value = str(item.value) == "True"
new_nb = nbparameterise.replace_definitions(nb, new_params)


if not nboargs.use_notebook_client:
    # In-process mode: every code cell is compiled and exec'd in this
    # script's own global namespace, one cell after another.
    for cell_index, cell in enumerate(new_nb["cells"]):
        if cell["cell_type"] != "code":
            continue
        if debug_cell_sourcecode:
            print(cell)
        if debug_print_cell_numbers:
            print(f"[{cell_index}]")
        # The pseudo-filename makes tracebacks point at the failing cell.
        code = compile(cell["source"], f"cell {cell_index}", "exec")
        exec(code)
        if debug_print_cell_numbers:
            print()
        if debug_sleep_between_cells > 0:
            time.sleep(debug_sleep_between_cells)


else:
    print("Using Notebook client")
    output = nbclient.execute(new_nb)
    for item in output["cells"]:
        # Markdown/raw cells have no "outputs" entry, so default to empty.
        for o in item.get("outputs", []):
            if "text" in o:
                # Stream output (stdout/stderr).
                text = o["text"]
            else:
                # Rich results keep their plain-text form in the mimebundle.
                text = o.get("data", {}).get("text/plain")
            if text is None:
                continue
            if isinstance(text, list):
                text = "".join(text)
            print(text.strip())
