#!python

import sys
import os
import argparse
import yaml

from npu_compiler import VERSION, NpuVersionManager

def get_config_from_file(config_file, quant):
    if isinstance(config_file, str):
        # input from yaml file
        try:
            with open(config_file) as f:
                config_dict = yaml.safe_load(f)
                config_dict["IS_QUANT"] = quant
                config_dir = os.path.dirname(config_file)
        except IOError:
            print("[ERROR] can't open config file: \"%s\"" % config_file)
            sys.exit(1)
    else:
        # input from stdin, e.g. `cat xxx.yaml|gxnpuc`
        config_dict = yaml.safe_load(sys.stdin)
        config_dict["IS_QUANT"] = quant
        config_dir = ""
    return config_dir, config_dict


def _build_parser():
    """Create the ``gxnpuc`` command-line argument parser."""
    parser = argparse.ArgumentParser(prog="gxnpuc", description="NPU Compiler")
    parser.add_argument("--cmpt", action="store_true",
            help="get version compatibility information between npu core, python, and frameworks")
    parser.add_argument("--list", action="store_true", help="list supported operators")
    parser.add_argument("-c", "--core_name", default="ALL",
            choices=NpuVersionManager.SUPPORT_CORE_VERSION,
            help="subargument to '--list', which is used to specify NPU Core when listing supported operators")
    parser.add_argument("-f", "--framework", default="ALL",
            choices=NpuVersionManager.SUPPORT_FRAMEWORKS,
            help="subargument to '--list', which is used to specify the type of deep learning framework when listing supported operators")
    parser.add_argument("-V", "--version", action="version", version="gxnpuc %s" % VERSION)
    parser.add_argument("-v", "--verbose", action="store_true", help="verbosely list the NPU model structure information")
    parser.add_argument("-m", "--meminfo", action="store_true", help="verbosely list memory information of operators")
    parser.add_argument("-w", "--weights", action="store_true", help="print compressed weights (GRUS only)")
    parser.add_argument("-s", "--save_hist", action="store_true", help="save histograms of weights value to 'npu_jpgs' directory (GRUS only)")
    parser.add_argument("-q", "--quant", action="store_true",
            help="inference and generate quant file (%s)" % NpuVersionManager.get_core_with_quant())
    # Without a positional argument the config is read from stdin (piped YAML).
    parser.add_argument("config_filename", nargs="?", default=sys.stdin, help="config file")
    return parser


def _list_supported_ops(core_name, framework):
    """Print the supported-operator tables for ``core_name`` / ``framework``.

    :return: process exit status -- 0 on success, 1 when ``framework`` is not
        available for ``core_name`` in the current Python environment
    """
    core_tables = NpuVersionManager.get_ops_table_dict()[core_name]
    if framework != "ALL":
        # "ALL" is the aggregate entry, not a framework the user can pick.
        supported_frameworks = [fw for fw in core_tables if fw != "ALL"]
        if framework not in supported_frameworks:
            print("[ERROR] The --framework/-f parameter values are incorrectly configured.")
            print("        NPU %s Compiler does not support handling %s framework models in python%d.%d environment!"
                    % (core_name, framework, sys.version_info[0], sys.version_info[1]))
            print("        When the '--core_name/-c' parameter value is configured to %s, the optional value for '--framework/-f' parameter is: %s."
                    % (core_name if core_name != "ALL" else "default", supported_frameworks))
            print("\nTry using the 'gxnpuc --cmpt' command to get version compatibility information between npu core, python, and frameworks!")
            return 1

    for table in core_tables[framework]:
        print(table.get_ops_note())
    return 0


def main():
    """Entry point of the ``gxnpuc`` command-line tool."""
    parser = _build_parser()
    args = parser.parse_args()

    if args.cmpt:
        print(NpuVersionManager.get_env_compatibility_info())
        sys.exit(0)
    if args.list:
        sys.exit(_list_supported_ops(args.core_name, args.framework))

    # The positional argument defaults to sys.stdin, which is always truthy;
    # treat an interactive terminal on stdin as "no config given" so the tool
    # shows usage instead of silently blocking on keyboard input.
    if not args.config_filename or (args.config_filename is sys.stdin and sys.stdin.isatty()):
        parser.print_help()
        sys.exit(0)

    config_dir, config_dict = get_config_from_file(args.config_filename, args.quant)

    config_para = {"VERBOSE": args.verbose, "MEMINFO": args.meminfo, "PRINT_WEIGHTS": args.weights,
            "SAVE_HIST": args.save_hist, "CONFIG_DIR": config_dir}
    corename = config_dict.get("CORENAME", "")
    core_funcs = NpuVersionManager.get_npu_funcs_dict().get(corename, {})
    if not core_funcs:
        print("CORENAME '%s' is not supported!" % corename)
        sys.exit(1)

    core_funcs.get("load")(config_dict, config_para)
    if args.quant:
        if not core_funcs.get("quant"):
            print("CORENAME '%s' doesn't need to run quantized inference" % corename)
            sys.exit(1)
        core_funcs.get("quant")()
    else:
        core_funcs.get("run")()


if __name__ == "__main__":
    main()
