#!/usr/bin/env python

import os
import sys
import shutil
import argparse
import subprocess
from pathlib import Path

class SetupError(Exception):
    """Raised when the SmartSim setup process cannot proceed or fails."""

def run_build(script_path, device, pt, tf, onnx, verbose):
    """Run the ``smartsim_setup`` script to build the requested backends.

    The script is invoked through the shell with the backend selection
    passed as environment-variable assignments (``PT``, ``TF``, ``TFL``
    and ``ONNX``) and the device as its single positional argument.

    :param script_path: directory containing the ``smartsim_setup`` script
    :type script_path: pathlib.Path
    :param device: device to build for; "gpu" (any case) is rejected on MacOS
    :type device: str
    :param pt: value exported as ``PT`` (1 to build PyTorch, 0 to skip)
    :type pt: int
    :param tf: value exported as ``TF`` (1 to build TensorFlow, 0 to skip)
    :type tf: int
    :param onnx: value exported as ``ONNX`` (1 to build ONNX, 0 to skip)
    :type onnx: int
    :param verbose: stream the script output to the terminal instead of
                    capturing it
    :type verbose: bool
    :raises SetupError: if the script is missing, the platform or the
                        requested combination is unsupported, or the
                        non-verbose build exits with a non-zero code
    :raises subprocess.CalledProcessError: if the verbose build exits with
                                           a non-zero code
    """
    # Imported locally so this function does not depend on a file-level
    # import that the original module never declared.
    import shlex

    script = script_path.joinpath("smartsim_setup")
    if not script.is_file():
        msg = "Could not find smartsim_setup script. "
        msg += "You may have to trigger manually from the installation site"
        raise SetupError(msg)

    if sys.platform == "darwin":
        if device.lower() == "gpu":
            raise SetupError("SmartSim does not support GPU builds on MacOS")
        if onnx == 1:
            raise SetupError("ONNX runtime is currently not supported on MacOS")

    if sys.platform in ["msys", "win32", "cygwin"]:
        msg = "Windows is not supported, but kudos to you for making it this far!"
        raise SetupError(msg)

    print("Running SmartSim build process...")
    # Every interpolated value is shell-quoted: the installation path may
    # contain spaces and ``device`` comes straight from the command line,
    # so neither may be allowed to split into words or be interpreted by
    # the shell.
    cmd = " ".join(
        (
            f"PT={shlex.quote(str(pt))}",
            f"TF={shlex.quote(str(tf))}",
            "TFL=0",
            f"ONNX={shlex.quote(str(onnx))}",
            shlex.quote(str(script)),
            shlex.quote(device),
        )
    )
    if verbose:
        subprocess.check_call(cmd, shell=True)
    else:
        proc = subprocess.run(
            cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE
        )
        returncode = int(proc.returncode)
        if returncode != 0:
            error = f"SmartSim setup failed with exitcode {returncode}\n"
            # Build tools may emit non-UTF-8 bytes; never let decoding mask
            # the real failure with a UnicodeDecodeError.
            error += proc.stderr.decode("utf-8", errors="replace")
            raise SetupError(error)
        print("SmartSim setup complete!")


def clean(install_path, _all=False):
    """Remove pre-existing installations of ML runtimes

    Always removes the ``.third-party`` build directory, ``lib/redisai.so``
    and the ``lib/backends`` directory found beneath ``install_path``.
    When ``_all`` is true, ``lib/libredisip.so`` and the ``redis-server``
    and ``redis-cli`` files in ``bin`` are removed as well. Anything that
    does not exist is silently skipped.

    :param install_path: path to installation
    :type install_path: pathlib.Path
    :param _all: Remove all non-python dependencies
    :type _all: bool, optional
    """
    build_temp = install_path.joinpath(".third-party")
    if build_temp.is_dir():
        shutil.rmtree(build_temp, ignore_errors=True)

    lib_path = install_path.joinpath("lib")
    if lib_path.is_dir():
        # remove RedisAI
        rai_path = lib_path.joinpath("redisai.so")
        if rai_path.is_file():
            rai_path.unlink()
            print("Successfully removed existing RedisAI installation")

        # remove the ML backends loaded by RedisAI
        backend_path = lib_path.joinpath("backends")
        if backend_path.is_dir():
            shutil.rmtree(backend_path, ignore_errors=True)
            print("Successfully removed ML runtimes")

        # remove redisip
        if _all:
            ip_path = lib_path.joinpath("libredisip.so")
            if ip_path.is_file():
                ip_path.unlink()
                print("Successfully removed existing RedisIP installation")

    bin_path = install_path.joinpath("bin")
    if bin_path.is_dir() and _all:
        files_to_remove = ["redis-server", "redis-cli"]
        removed = False
        for _file in files_to_remove:
            file_path = bin_path.joinpath(_file)
            if file_path.is_file():
                removed = True
                file_path.unlink()
        # report once, and only if at least one binary was actually deleted
        if removed:
            print("Successfully removed SmartSim Redis installation")


def colorize(string, color, bold=False, highlight=False):
    """
    Wrap a string in ANSI escape sequences so it prints in color.

    ``color`` must be one of the names in the table below; ``highlight``
    shifts the code by 10 and ``bold`` appends the "1" attribute.

    This function was originally written by John Schulman.
    And then borrowed from spinningup
    https://github.com/openai/spinningup/blob/master/spinup/utils/logx.py
    """
    codes = {
        "gray": 30,
        "red": 31,
        "green": 32,
        "yellow": 33,
        "blue": 34,
        "magenta": 35,
        "cyan": 36,
        "white": 37,
        "crimson": 38,
    }

    code = codes[color] + (10 if highlight else 0)
    attributes = [str(code)]
    if bold:
        attributes.append("1")
    joined = ";".join(attributes)
    return f"\x1b[{joined}m{string!s}\x1b[0m"

def make_description():
    """Build the colorized banner shown before the argparse help text.

    The banner consists of a bold cyan header, a green explanation of what
    the script does, a cyan "Supported Systems" title followed by the green
    list of systems, and a green note about the default backends.

    :return: the complete description, including ANSI color codes
    :rtype: str
    """
    header = colorize(
        "\nSmartSim Setup\n"
        "--------------\n\n",
        bold=True,
        color="cyan",
    )

    info = colorize(
        "This script will setup the dependencies for SmartSim\n"
        "If compatible, SmartSim will download and install\n"
        "shared libraries for requested ML runtimes.\n\n"
        "If the system OS is not supported, or the user\n"
        "asks for backends that are not pre-built, this\n"
        "script will attempt to build the backends.\n\n",
        color="green",
    )

    # NOTE: a paragraph about building non-python dependencies from source
    # used to be assembled here but was never part of the returned text, so
    # it is no longer built.

    prebuilts = colorize("Supported Systems\n", color="cyan")
    systems = colorize(
        "  1) Linux CPU (x86_64)\n"
        "  2) MacOS CPU (x86_64)\n"
        "  3) Linux GPU (CUDA 10.2 or greater)\n\n",
        color="green",
    )

    default = colorize(
        "By default, the PyTorch and TensorFlow backends will\n"
        "be installed.\n",
        color="green",
    )

    return "".join((header, info, prebuilts, systems, default))



def cli():
    """Command-line entry point for the SmartSim setup script.

    Locates the installed ``smartsim`` package, then either prints the
    description and usage (no arguments, or a help flag as the first
    argument), removes previous installations (``--clean`` / ``--clobber``)
    and exits with status 0, or prints the requested backends and runs the
    build.

    :raises SetupError: if SmartSim cannot be imported or its installation
                        directory cannot be found
    """
    # Try to import SmartSim
    try:
        import smartsim
    except ImportError:
        # ModuleNotFoundError is a subclass of ImportError, so this single
        # clause covers both.
        raise SetupError("Could not import SmartSim") from None

    # find the path to the setup script
    package_path = Path(smartsim.__path__[0]).resolve()
    if not package_path.is_dir():
        raise SetupError("Could not find SmartSim installation site")

    script_path = package_path.joinpath("bin/")

    parser = argparse.ArgumentParser(description="SmartSim Setup")
    parser.add_argument('-v', action="store_true", default=False,
                        help='Enable verbose build process')
    parser.add_argument('--clean', action="store_true", default=False,
                        help='Remove previous ML runtime installation')
    parser.add_argument('--clobber', action="store_true", default=False,
                        help='Remove all SmartSim non-python dependencies to build from scratch')
    parser.add_argument('--device', type=str, default="cpu",
                        help='Device to build ML runtimes for (cpu || gpu)')
    parser.add_argument('--no_pt', action="store_true", default=False,
                        help='Do not build PyTorch backend')
    parser.add_argument('--no_tf', action="store_true", default=False,
                        help='Do not build TensorFlow backend')
    parser.add_argument('--onnx', action="store_true", default=False,
                        help='Build ONNX backend (off by default)')

    # display help: either nothing was asked for, or help was requested
    # explicitly as the first argument
    if len(sys.argv) < 2 or sys.argv[1] in ["-h", "h", "--help", "help"]:
        print(make_description())
        parser.print_help()
        return

    args = parser.parse_args()

    # clean previous installations
    # sys.exit is used rather than the ``exit`` builtin, which is injected
    # by the ``site`` module and is not guaranteed to exist.
    if args.clobber:
        clean(package_path, _all=True)
        sys.exit(0)
    elif args.clean:
        clean(package_path, _all=False)
        sys.exit(0)

    def color(is_in_build=True):
        """Render a boolean as green "True" or red "False"."""
        _color = "green" if is_in_build else "red"
        return colorize(str(is_in_build), color=_color)

    # decide which runtimes to build
    print("\nBackends Requested")
    print("-------------------")
    print(f"    PyTorch: {color(not args.no_pt)}")
    print(f"    TensorFlow: {color(not args.no_tf)}")
    print(f"    ONNX: {color(args.onnx)}")
    print("\n")
    pt = 0 if args.no_pt else 1
    tf = 0 if args.no_tf else 1
    onnx = 1 if args.onnx else 0
    run_build(script_path, args.device, pt, tf, onnx, args.v)

# Run the command-line interface only when executed as a script.
if __name__ == "__main__":
    cli()