#!/usr/bin/env python3
#!/usr/bin/python3

from posixpath import join
from panoramacli import __version__
import sys
import os
import subprocess
import json
import uuid
import time
import shutil
import argparse
import hashlib
import time
import re
import pkg_resources

# Static values
# Argument/key names for the model-compilation flow. Of these, only
# COMPILE_OUTPUT_URI is referenced in this part of the file (as a var.json key,
# see compile_model / get_compilation_details); the rest may be used elsewhere.
RUNNING_APP_LIST = 'running_apps'
# NOTE(review): DOCKERFILE_PATH is rebound further down to the bundled
# Dockerfile resource path, so this value is never observed.
DOCKERFILE_PATH = 'dockerfile_path'
MODEL_DIR = 'model_dir'
SCRIPT_PATH = 'script_path'
COMPILE_ROLE_ARN = 'role_arn_for_model_compilation'
COMPILE_INPUT_URI = 'model_s3_input_uri'
INPUT_SHAPE = 'input_shape'
FRAMEWORK = 'framework'
COMPILE_OUTPUT_URI = 'model_s3_output_uri'

# Distribution name passed to pkg_resources when loading bundled templates.
pypi_package_name = "panoramacli"
# Scratch file in the cwd that remembers the last compilation job between commands.
VAR_PATH = 'var.json'
# Template resources bundled with the package (paths relative to the package root).
PROJECT_SKELETON_PATH = 'resources/project_skeleton'
PACKAGE_SKELETON_PATH = 'resources/package_skeleton'
GRAPH_TEMPLATE_PATH = 'resources/graph_template.json'
PACKAGE_TEMPLATE_PATH = 'resources/package_template.json'
DOCKERFILE_PATH = 'resources/Dockerfile'
DESCRIPTOR_CONTAINER_TEMPLATE_PATH = 'resources/descriptor_container_template.json'
DESCRIPTOR_MODEL_TEMPLATE_PATH = 'resources/descriptor_model_template.json'
DOWNLOAD_ASSET_TEMPLATE_PATH = 'resources/download_model.json'
# ("PACAKAGE" is a historical misspelling kept for compatibility with existing references.)
CAMERA_PACAKAGE_TEMPLATE_PATH='resources/camera_package.json'
DATASINK_PACAKAGE_TEMPLATE_PATH='resources/datasink_package.json'
BUILD_ASSET_TEMPLATE_PATH = 'resources/build_package.json'
GRAPH_NODE_TEMPLATE_PATH = 'resources/graph_node_template.json'
PACKAGE_INTERFACE_TEMPLATE_PATH = 'resources/package_interface_template.json'
SERVICE_CLI_APIS = 'resources/OmniCloudServiceLambda.api.json'
# var.json key holding the compilation job name.
JOB_ID = 'job_id'
# Fields read from `aws sagemaker describe-compilation-job` responses.
COMPILATION_JOB_STATUS = 'CompilationJobStatus'
FAILURE_REASON = 'FailureReason'

# Directory containing this script; used to locate bundled resources by file path.
CLI_DIR = os.path.dirname(os.path.realpath(__file__))

# Compiler options sent as CompilerOptions in the SageMaker output config.
# This is subject to change once there is jetpack upgrade.
COMPILER_OPTIONS = {
    "gpu-code": "sm_72",
    "trt-ver": "7.1.3",
    "cuda-ver": "10.2"
}

# Target device description sent as TargetPlatform in the SageMaker output config.
TARGET_PLATFORM = {
    "Os": "LINUX",
    "Arch": "ARM64",
    "Accelerator": "NVIDIA"
}

# Passed as --stopping-condition to `aws sagemaker create-compilation-job`.
STOPPING_CONDITION = {
    "MaxRuntimeInSeconds": 900,
    "MaxWaitTimeInSeconds": 60
}


# Utility classes
class InputConfig:
    """Input settings for a SageMaker compilation job.

    ``str()`` renders the value for the ``--input-config`` argument of
    ``aws sagemaker create-compilation-job`` as a double-quoted, shell-escaped
    JSON object. Its DataInputConfig field is emitted as a quoted string that
    contains another object, hence the extra level of escaping.
    """

    def __init__(self, s3uri, framework, input_shape):
        # S3 location of the model artifacts to compile.
        self.s3uri = s3uri
        # Framework name, forwarded verbatim as "Framework".
        self.framework = framework
        # Value placed under the "input" key of DataInputConfig.
        self.input_shape = input_shape

    # TODO: use package for serialization.
    def __str__(self):
        data_input = serialize_dict_with_quote({"input": self.input_shape}, 1)
        pieces = [
            '\"{',
            '\\\"S3Uri\\\": \\\"' + self.s3uri + '\\\", ',
            '\\\"DataInputConfig\\\": \\\"{' + data_input + '}\\\", ',
            '\\\"Framework\\\": \\\"' + self.framework + '\\\"',
            '}\"',
        ]
        return ''.join(pieces)


class OutputConfig:
    """Output settings for a SageMaker compilation job.

    ``str()`` renders the value for the ``--output-config`` argument of
    ``aws sagemaker create-compilation-job``. TargetPlatform is emitted as a
    nested object. CompilerOptions is emitted as a quoted string containing an
    object, which needs one more level of escaping.
    """

    def __init__(self, s3uri, compiler_options, target_platform):
        # S3 prefix where the compiled model is written.
        self.s3uri = s3uri
        # Mapping rendered into the CompilerOptions string.
        self.compiler_options = compiler_options
        # Mapping rendered into the TargetPlatform object.
        self.target_platform = target_platform

    def __str__(self):
        platform_fields = serialize_dict_with_quote(self.target_platform, 2)
        option_fields = serialize_dict_with_quote(self.compiler_options, 3)
        return ''.join([
            '\"{',
            '\\\"S3OutputLocation\\\": \\\"' + self.s3uri + '\\\", ',
            '\\\"TargetPlatform\\\": {' + platform_fields + '}, ',
            '\\\"CompilerOptions\\\": \\\"{' + option_fields + '}\\\"',
            '}\"',
        ])

def json_load_decorator(json_load_func):
    """Wrap a JSON-loading callable so parse errors end the CLI with a readable message.

    The returned function takes a single argument *f* (normally an open file
    object) and returns whatever ``json_load_func(f)`` returns. If the loader
    raises, the exception and the offending file name are printed. The process
    then exits with status 1.
    """
    import functools

    @functools.wraps(json_load_func)
    def catch_json_load_errors(f):
        try:
            return json_load_func(f)
        except Exception as e:
            print(e)
            # Not every file-like object has .name (e.g. io.StringIO); fall back to
            # the object itself so the error path cannot raise AttributeError.
            print("JSON formatting error in ", getattr(f, "name", f))
            print("Please fix the error and try again")
            # Non-zero status: a malformed file is a failure, not a clean exit.
            sys.exit(1)
    return catch_json_load_errors

# Utility functions
def execute(cmd, runtime_print_output=False, logger=None):
    """Run *cmd* through the shell and return ``(returncode, output)``.

    Args:
        cmd: a shell command string, or a list whose first item is the command
            string. With ``shell=True`` on POSIX, any further list items are
            passed as arguments to the shell itself.
        runtime_print_output: when True, stdout/stderr go straight to the
            terminal and the returned output is None.
        logger: optional logger that receives the command and the captured
            output at DEBUG level. It is only used when output is captured.

    Returns:
        Tuple of the process return code and the combined stdout+stderr text,
        or None when *runtime_print_output* is True.
    """
    if runtime_print_output:
        proc = subprocess.run(cmd, shell=True)
        return proc.returncode, None

    proc = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
    # Tool output is not guaranteed to be valid UTF-8; never fail just decoding it.
    stdout = proc.stdout.decode('utf-8', errors='replace')
    if logger is not None:
        # %-style arguments: callers often pass cmd as a list, and
        # 'str' + list would raise TypeError.
        logger.debug('Executing: %s', cmd)
        logger.debug('%s', stdout)
    return proc.returncode, stdout

def serialize_dict_with_quote(dict_to_prt, option):
    r"""Render *dict_to_prt* as comma-separated ``key: value`` pairs with escaped quotes.

    The output is meant to be spliced into shell-quoted JSON (see InputConfig and
    OutputConfig). *option* chooses the quoting:

    * 1 -> ``\\\"key\\\": value``
    * 2 -> ``\"key\": \"value\"``
    * 3 -> ``\\\"key\\\": \\\"value\\\"``
    * 4 -> ``\"key\": value``

    Values are converted with ``str()``. Entries are joined with ``", "``. An
    empty mapping or an unknown option yields ``''``.
    """
    single = '\\\"'      # the two characters \"
    triple = '\\\\\\"'   # the four characters \\\"
    if option == 1:
        key_quote, value_quote = triple, ''
    elif option == 2:
        key_quote, value_quote = single, single
    elif option == 3:
        key_quote, value_quote = triple, triple
    elif option == 4:
        key_quote, value_quote = single, ''
    else:
        return ''
    entries = [
        key_quote + key + key_quote + ': ' + value_quote + str(value) + value_quote
        for key, value in dict_to_prt.items()
    ]
    return ', '.join(entries)


def get_model_compilation_info(job_id):
    """Build the role ARN and input/output configs for a compilation job from CLI args.

    The output location is ``args.model_s3_output_uri`` with *job_id* appended.
    Both configs are printed for the user before being returned.

    Returns:
        ``(role_arn, InputConfig, OutputConfig)``.
    """
    role_arn = args.compile_role_arn

    input_config = InputConfig(args.model_s3_input_uri, args.framework, args.input_shape)
    print('Input config is: ' + str(input_config))

    output_location = args.model_s3_output_uri + job_id
    output_config = OutputConfig(output_location, COMPILER_OPTIONS, TARGET_PLATFORM)
    print('Output config is: ' + str(output_config) + '\n')

    return role_arn, input_config, output_config


def get_compilation_details():
    """Read the last compilation job from var.json.

    Returns:
        ``(job_id, output_uri)``, where output_uri is the stored output prefix
        followed by ``<job_id>/``.

    Exits the process if var.json is not present in the current directory.
    """
    try:
        with open(VAR_PATH, 'r') as var_file:
            saved = json.load(var_file)
    except FileNotFoundError:
        sys.exit('var.json not found in current directory. Make sure you run model compilation first.')

    job_id = saved[JOB_ID]
    return job_id, saved[COMPILE_OUTPUT_URI] + job_id + '/'


def get_compilation_job_response(job_id):
    """Return the parsed ``aws sagemaker describe-compilation-job`` JSON for *job_id*.

    On CLI failure the command output is printed and the process exits.
    """
    return_code, output = execute(
        'aws sagemaker describe-compilation-job --output json --compilation-job-name ' + job_id)
    if return_code == 0:
        return json.loads(output)
    print(output)
    sys.exit('Error when fetching compilation job status.')

def get_hash(name):
    """Return the hex SHA-256 digest of the UTF-8 encoding of *name*."""
    digest = hashlib.sha256()
    digest.update(name.encode('utf-8'))
    return digest.hexdigest()

def create_ext4_fs_image(tar_path, image_name):
    """Create an ext4 disk image *image_name* containing the extracted contents of *tar_path*.

    Steps:
    1. Allocate a 512M zero-filled file with ``dd``.
    2. Attach it to a free loop device and format it with ``mkfs.ext4``.
    3. Mount it under ``/mnt/<image name up to the first dot>``.
    4. Copy the archive in, extract it, and delete the archive.
    5. Unmount and detach.

    Presumably needs root for losetup/mount (TODO confirm). Exits the process
    on the first failing shell command.
    """
    return_code, out = execute(["dd if=/dev/zero of=" + image_name + " bs=512M count=1"])
    if return_code != 0:
        print(out)
        sys.exit('Error while creating an empty disk image')

    # Check the status before using the output: on failure it is an error
    # message, not a device path. Print the losetup output, not the dd output.
    return_code, out = execute(["losetup -f"])
    if return_code != 0:
        print(out)
        sys.exit('Error while trying to get empty loop device')
    empty_loop_device = out.rstrip()
    mount_path = "/mnt/" + image_name.split('.')[0]

    # NOTE(review): the final cleanup removes the hard-coded /mnt/dx_tar_file rather
    # than mount_path; kept as-is, but confirm which directory is intended.
    commands = ["losetup --find --show " + image_name + " " + empty_loop_device,
                "mkfs.ext4 " + empty_loop_device,
                "mkdir -p " + mount_path,
                "mount " + empty_loop_device + " " + mount_path,
                "cp " + tar_path + " " + mount_path,
                "tar -xvf " + mount_path + "/*.tar.gz -C " + mount_path,
                "rm " + mount_path + "/*.tar.gz",
                "umount " + mount_path,
                "losetup -d " + empty_loop_device,
                "rm -rf /mnt/dx_tar_file"]
    for cmd in commands:
        print(cmd)
        return_code, out = execute([cmd])
        if return_code != 0:
            print(out)
            sys.exit('Error while creating a fs image from tar file')
    print("Created image " + image_name + " from the tar file")

def create_squash_fs_image(tar_path, image_name):
    """Turn *tar_path* into a squashfs image named *image_name*.

    Steps:
    1. Extract the archive into its own directory and delete the archive.
    2. Build the image from that directory with ``mksquashfs``.
    3. Grow the image by 3M with ``truncate``.
    4. Remove the directory.

    Exits the process on the first failing command.
    """
    work_dir = os.path.dirname(tar_path)
    steps = [
        "tar -xvf " + tar_path + " --directory " + work_dir,
        "rm -rf " + tar_path,
        "mksquashfs " + work_dir + ' ' + image_name,
        "truncate -s +3M " + image_name,
        "rm -rf " + work_dir,
    ]
    for step in steps:
        print(step)
        status, output = execute([step])
        if status == 0:
            continue
        print(output)
        sys.exit('Error while creating a fs image from tar file')
    print("Created image " + image_name + " from the tar file")

def create_tar_asset(tar_path, tar_name):
    """Re-pack *tar_path* as the gzip-compressed tarball *tar_name*.

    Steps:
    1. Extract the archive into its own directory and delete the original archive.
    2. Archive that directory with ``tar -czvf``.
    3. Remove the directory.

    Exits the process on the first failing command.
    """
    work_dir = os.path.dirname(tar_path)
    steps = [
        "tar -xvf " + tar_path + " --directory " + work_dir,
        "rm -rf " + tar_path,
        "tar -czvf " + tar_name + " " + work_dir,
        "rm -rf " + work_dir,
    ]
    for step in steps:
        print(step)
        status, output = execute([step])
        if status == 0:
            continue
        print(output)
        sys.exit('Error while creating a tar asset')
    print("Created asset  " + tar_name)

def get_file_sha_hash(file_path, chunk_size=1024 * 1024):
    """Return the hex SHA-256 digest of the file at *file_path*.

    The file is read in *chunk_size*-byte blocks, so exported container and model
    archives (which can be several GB) do not have to fit in memory at once.
    """
    digest = hashlib.sha256()
    with open(file_path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()

def get_aws_account_id():
    """Return the caller's AWS account ID from ``aws sts get-caller-identity``.

    All non-alphanumeric characters (such as the trailing newline) are stripped
    from the CLI output. Exits the process if the CLI call fails.
    """
    status, output = execute(['aws sts get-caller-identity --query Account --output text'])
    if status == 0:
        # Strip everything that is not a letter or digit from the raw stdout.
        return re.sub('[^A-Za-z0-9]+', '', output)
    print(output)
    sys.exit('Error getting AWS account ID, please configure your account before proceeding')

def throw_error(status, error, error_message, exit=True):
    """Report a failed command and, by default, terminate the CLI.

    Args:
        status: return code of the failed command (not used here).
        error: captured command output, or a falsy value if none. When it
            mentions "Could not connect", a network hint is printed as well.
        error_message: summary shown to the user; used as the exit message.
        exit: when False, print *error_message* and return instead of exiting.
    """
    if error:
        print(error)
        network_failure = "Could not connect" in error
        if network_failure:
            print("Network issue while connecting to the endpoint, check your connection and try again")
    if not exit:
        print(error_message)
        return
    sys.exit(error_message)

def verify_cwd_is_project_root(func):
    """Decorator: exit unless the cwd is an application root (contains assets/, graphs/ and packages/).

    The wrapped function's arguments are forwarded and its return value is
    passed back to the caller. ``functools.wraps`` keeps its name and docstring.
    """
    import functools

    @functools.wraps(func)
    def cwd_check_wrapper(*func_args, **func_kwargs):
        cwd_entries = set(os.listdir(os.getcwd()))
        for required_dir in ("assets", "graphs", "packages"):
            if required_dir not in cwd_entries:
                sys.exit("panorama-cli can only be used from application root directory, cd to application root and try again")
        # Forward the result: without this, decorated helpers such as
        # get_graph_json_path() would always return None.
        return func(*func_args, **func_kwargs)
    return cwd_check_wrapper

def get_absolute_path(path):
    """Return *path* unchanged if it is absolute, else joined onto the current working directory.

    The result is not normalized (``..`` components are kept).
    """
    return path if os.path.isabs(path) else os.path.join(os.getcwd(), path)

def get_path_with_espace_characters(path):
    r"""Backslash-escape shell metacharacters in *path* for use in an unquoted shell word.

    The escaped characters are: space, tab, backslash, single and double quotes,
    ``$``, backtick, ``!``, ``;``, ``|``, ``&``, ``(``, ``)``, ``*``, ``?``,
    ``<``, ``>``, ``[``, ``]``, ``{``, ``}``, ``~`` and ``#``. Each one gets a
    preceding ``\``.

    A single ``str.translate`` pass is used, so the backslash itself is escaped
    without double-escaping the backslashes added for other characters.
    Newlines cannot be escaped this way; paths containing them still need real
    shell quoting.
    """
    special_chars = " \t\\'\"$`!;|&()*?<>[]{}~#"
    return path.translate({ord(ch): "\\" + ch for ch in special_chars})

@verify_cwd_is_project_root
def get_graph_json_path():
    """Print and return the path of graph.json in the first directory under ./graphs.

    "First" follows ``os.listdir`` order, which is arbitrary. Raises IndexError
    when ./graphs has no sub-directory. Callers only receive the return value if
    verify_cwd_is_project_root forwards it.
    """
    graphs_root = os.path.join(os.getcwd(), "graphs")
    graph_dirs = list(filter(lambda entry: os.path.isdir(os.path.join(graphs_root, entry)),
                             os.listdir(graphs_root)))
    json_path = os.path.join(graphs_root, graph_dirs[0], "graph.json")
    print(json_path)
    return json_path

def add_graph_node(node_name, package_interface):
    """Append a node to the application's graph.json.

    The node is built from the bundled graph-node template, with its name set
    to *node_name* and its interface set to *package_interface*. graph.json is
    taken from the first directory under ./graphs (``os.listdir`` order).
    Raises IndexError if there is none.
    """
    graphs_path = os.path.join(os.getcwd(), "graphs")
    graph_dir = [a for a in os.listdir(graphs_path) if os.path.isdir(os.path.join(graphs_path, a))][0]
    graph_json_path = os.path.join(graphs_path, graph_dir, "graph.json")
    with open(graph_json_path, "r") as f:
        graph_json = json.load(f)

    node_template = json.loads(pkg_resources.resource_string(pypi_package_name, GRAPH_NODE_TEMPLATE_PATH))
    node_template["name"] = node_name
    node_template["interface"] = package_interface
    graph_json["nodeGraph"]["nodes"].append(node_template)

    # "w" truncates the file. Writing in "r+" mode would leave stale trailing
    # bytes (invalid JSON) whenever the new serialization is shorter than the
    # existing file, e.g. when it was written with wider indentation.
    with open(graph_json_path, "w") as f:
        json.dump(graph_json, f, indent=4)

def interface_exists(interfaces, asset_name):
    """Return True if any interface dict in *interfaces* has ``"asset"`` equal to *asset_name*."""
    return any(entry["asset"] == asset_name for entry in interfaces)

def get_package_name(package_path):
    """Extract the package name from a ``<account>-<name>-<version>`` package directory.

    Example: ``.../619501627742-people-counter-package-1.0`` gives
    ``people-counter-package``. The text before the first ``-`` and after the
    last ``-`` is dropped, and a trailing path separator is ignored.
    """
    # TODO: too hacky right now; relies on the directory naming convention.
    dir_name = os.path.basename(os.path.normpath(package_path))
    without_account = dir_name.partition('-')[2]
    return without_account.rpartition('-')[0]

def build_package(docker_build=True):
    """Produce the container asset for ``args.package_path`` and register it in package.json.

    Source image:
    * ``docker_build=True``: the package directory is built with ``docker build``
      (tagged ``args.container_asset_name``), and a container created from that
      image is exported.
    * ``docker_build=False``: a container created from ``args.container_image_uri``
      is exported without building.

    Asset files: the exported tar is gzipped, renamed to ``<sha256>.tar.gz`` and
    moved into ./assets. The package's descriptor.json is copied to
    ``./assets/<sha256>.json``.

    package.json: any asset with the same name is replaced by the new one, and
    its previous asset/descriptor files are deleted when their names differ. A
    brand-new asset without an interface also gets a ``business_logic``
    interface and a default ``<asset>_node`` graph node.

    Exits the process (via throw_error) when a docker or gzip command fails.
    """
    asset_name = args.container_asset_name
    tar_name = asset_name + ".tar"
    gz_name = tar_name + ".gz"

    # Clear leftovers from a previous (possibly interrupted) run. A stale .tar.gz
    # matters as well, because gzip will not overwrite an existing output file.
    for leftover in (tar_name, gz_name):
        leftover_path = os.path.join(os.getcwd(), leftover)
        if os.path.exists(leftover_path):
            os.remove(leftover_path)

    if docker_build:
        commands = ["TMPDIR=$(pwd) docker build -t " + asset_name + " " + args.package_path + " --pull",
                    "docker export --output=" + asset_name + ".tar $(docker create " + asset_name + ":latest)"]
    else:
        commands = ["docker export --output=" + asset_name + ".tar $(docker create " + args.container_image_uri + ")"]
    for cmd in commands:
        print(cmd)
        return_code, out = execute([cmd], runtime_print_output=True)
        if return_code != 0:
            throw_error(return_code, out, 'Error while creating and exporting a docker image')

    # Descriptors are content-addressed: the file name is the SHA-256 of the contents.
    descriptor_path = os.path.join(args.package_path, "descriptor.json")
    descriptor_uri = get_file_sha_hash(descriptor_path) + ".json"
    descriptor_dst_final_path = os.path.join(os.getcwd(), "assets", descriptor_uri)
    shutil.copyfile(descriptor_path, descriptor_dst_final_path)

    gzip_cmd = "gzip -9 " + tar_name
    print(gzip_cmd)
    return_code, out = execute([gzip_cmd], runtime_print_output=True)
    if return_code != 0:
        throw_error(return_code, out, 'Error while compressing the exported docker tar file')

    # gzip replaced <asset>.tar with <asset>.tar.gz; move it into assets under its content hash.
    image_src_path = os.path.join(os.getcwd(), gz_name)
    image_name = get_file_sha_hash(image_src_path) + ".tar.gz"
    image_dst_path = os.path.join(os.getcwd(), "assets", image_name)
    shutil.move(image_src_path, image_dst_path)

    asset_template = json.loads(pkg_resources.resource_string(pypi_package_name, BUILD_ASSET_TEMPLATE_PATH))
    asset_template["name"] = asset_name
    asset_template["implementations"][0]["assetUri"] = image_name
    asset_template["implementations"][0]["descriptorUri"] = descriptor_uri

    interface_template = json.loads(pkg_resources.resource_string(pypi_package_name, PACKAGE_INTERFACE_TEMPLATE_PATH))
    interface_template["name"] = asset_name + "_interface"
    interface_template["category"] = "business_logic"
    interface_template["asset"] = asset_name

    package_json_path = os.path.join(args.package_path, "package.json")
    assets_dir = os.path.join(os.getcwd(), 'assets')
    with open(package_json_path, "r+") as f:
        package_json = json.load(f)
        # The new asset goes first; other assets are kept in their original order.
        assets = [asset_template]
        new_asset = True
        for asset in package_json["nodePackage"]["assets"]:
            if asset["name"] != asset_name:
                assets.append(asset)
                continue
            print("Updating an existing asset with the same name")
            new_asset = False
            # Drop files of the replaced implementation unless they are byte-identical
            # (same hash-derived name) to the ones just produced.
            for implementation in asset["implementations"]:
                asset_uri = implementation["assetUri"]
                old_descriptor_uri = implementation["descriptorUri"]
                if asset_uri != image_name:
                    print("Deleting old asset " + asset_uri)
                    old_asset_path = os.path.join(assets_dir, asset_uri)
                    if os.path.exists(old_asset_path):
                        os.remove(old_asset_path)
                if old_descriptor_uri != descriptor_uri:
                    print("Deleting old descriptor " + old_descriptor_uri)
                    old_descriptor_path = os.path.join(assets_dir, old_descriptor_uri)
                    if os.path.exists(old_descriptor_path):
                        os.remove(old_descriptor_path)
        package_json["nodePackage"]["assets"] = assets
        if new_asset and not interface_exists(package_json["nodePackage"]["interfaces"], asset_name):
            package_json["nodePackage"]["interfaces"].append(interface_template)
            default_node_name = asset_name + "_node"
            account_id = get_aws_account_id()
            full_interface_path = account_id + "::" + get_package_name(args.package_path) + "." + interface_template["name"]
            add_graph_node(default_node_name, full_interface_path)
        f.seek(0)
        json.dump(package_json, f, indent=4)
        f.truncate()
    print(json.dumps(asset_template, indent=4))
    print("Container asset for the package has been succesfully built at ", image_dst_path)

def compile_model():
    """Start a SageMaker Neo compilation job and remember it in var.json.

    The job name is the first segment of a random UUID. On success, the job id
    and the output URI prefix are written to var.json so a later
    download_compiled_model call can find them. On failure a message is printed.
    """
    job_id = str(uuid.uuid4()).split('-')[0]
    role_arn, input_config, output_config = get_model_compilation_info(job_id)

    stop_condition = '\"{' + serialize_dict_with_quote(STOPPING_CONDITION, 4) + '}\"'

    cmd = ' '.join([
        'aws sagemaker create-compilation-job',
        '--compilation-job-name', job_id,
        '--role-arn', role_arn,
        '--input-config', str(input_config),
        '--output-config', str(output_config),
        '--stopping-condition', stop_condition,
    ])
    print('Calling Sagemaker API: ' + cmd)
    return_code, out = execute(cmd)
    print(out)
    if return_code != 0:
        print('Compilation job creation failure.')
        return

    print('Successfully created model compilation job with id: ' + job_id)
    job_record = {JOB_ID: job_id, COMPILE_OUTPUT_URI: args.model_s3_output_uri}
    with open(VAR_PATH, 'w+') as var_file:
        json.dump(job_record, var_file, indent=4)

@verify_cwd_is_project_root
def build():
    """Build the package's container image with docker and register it as an asset."""
    build_package(docker_build=True)

@verify_cwd_is_project_root
def export():
    """Export an existing container image (no docker build) as the package's container asset.

    Guarded by the same project-root check as ``build``. build_package writes
    into ./assets and the graph, so running it elsewhere would fail only after
    the docker export had already been done.
    """
    build_package(docker_build=False)

@verify_cwd_is_project_root
def download_raw_model():
    """Deprecated command; it only tells the user to switch to add_raw_model."""
    notice = "download_raw_model command is deprecated, use add_raw_model instead"
    print(notice)

@verify_cwd_is_project_root
def add_raw_model():
    """Add a model archive to ./assets as a content-addressed asset and optionally register it.

    Model source, in order of preference:
    1. ``args.model_s3_uri``, downloaded with ``aws s3 cp``.
    2. ``args.model_local_path``, copied.
    3. Neither: an archive already present at ``./assets/<model_asset_name>.tar.gz``.

    The archive is renamed to ``<sha256>.tar.gz``. The descriptor at
    ``args.descriptor_path`` is copied to ``./assets/<sha256>.json``.

    Registration: for each path in ``args.packages_path``, the package.json
    asset with the same name is replaced, and its previous asset/descriptor
    files are deleted when their names differ. A new asset also gets an
    ``ml_model`` interface and a graph node named after the asset. Without
    packages, the asset entry is only printed for manual copying.
    """
    import shlex

    file_base_name = args.model_asset_name
    asset_dir = os.path.join(os.getcwd(), "assets")
    dst_model_path = os.path.join(asset_dir, file_base_name + '.tar.gz')

    if args.model_s3_uri:
        s3_uri = args.model_s3_uri
        # Quote both operands: the destination embeds the cwd, which may contain spaces.
        cmd = 'aws s3 cp ' + shlex.quote(s3_uri) + ' ' + shlex.quote(dst_model_path)

        return_code, out = execute(cmd, runtime_print_output=True)
        if return_code != 0:
            throw_error(return_code, out, 'Error when downloading compiled artifacts (' + s3_uri + ') to ' + asset_dir)
    elif args.model_local_path:
        src_model_path = get_absolute_path(args.model_local_path)
        shutil.copy(src_model_path, dst_model_path)
    elif not os.path.exists(dst_model_path):
        # Fail up front with a clear message instead of a FileNotFoundError after
        # the descriptor has already been copied into ./assets.
        sys.exit('No model source given (model S3 URI or local model path) and ' + dst_model_path + ' does not exist')

    try:
        descriptor_src_path = os.path.join(os.getcwd(), args.descriptor_path)
        descriptor_uri = get_file_sha_hash(descriptor_src_path) + ".json"
        descriptor_dst_path = os.path.join(asset_dir, descriptor_uri)
        shutil.copyfile(descriptor_src_path, descriptor_dst_path)
    except Exception as e:
        print(e)
        print("Error while reading descriptor, check --descriptor-path to make sure descriptor path is correct")
        # Non-zero status: this is a failure, not a successful run.
        sys.exit(1)

    model_tar_name = get_file_sha_hash(dst_model_path) + ".tar.gz"
    final_model_dst_path = os.path.join(asset_dir, model_tar_name)
    shutil.move(dst_model_path, final_model_dst_path)

    asset_template = json.loads(pkg_resources.resource_string(pypi_package_name, DOWNLOAD_ASSET_TEMPLATE_PATH))
    asset_template["name"] = args.model_asset_name
    asset_template["implementations"][0]["assetUri"] = model_tar_name
    asset_template["implementations"][0]["descriptorUri"] = descriptor_uri

    interface_template = json.loads(pkg_resources.resource_string(pypi_package_name, PACKAGE_INTERFACE_TEMPLATE_PATH))
    interface_template["name"] = args.model_asset_name + "_interface"
    interface_template["category"] = "ml_model"
    interface_template["asset"] = args.model_asset_name

    if args.packages_path:
        for package_path in args.packages_path:
            package_json_path = os.path.join(package_path, "package.json")
            with open(package_json_path, "r+") as f:
                package_json = json.load(f)
                # The new asset goes first; other assets keep their original order.
                assets = [asset_template]
                new_asset = True
                for asset in package_json["nodePackage"]["assets"]:
                    if asset["name"] != args.model_asset_name:
                        assets.append(asset)
                        continue
                    new_asset = False
                    print("Updating an existing asset with the same name")
                    for implementation in asset["implementations"]:
                        asset_uri = implementation["assetUri"]
                        old_descriptor_uri = implementation["descriptorUri"]
                        if asset_uri != model_tar_name:
                            print("Deleting old asset " + asset_uri)
                            old_asset_path = os.path.join(asset_dir, asset_uri)
                            if os.path.exists(old_asset_path):
                                os.remove(old_asset_path)
                        if old_descriptor_uri != descriptor_uri:
                            print("Deleting old descriptor " + old_descriptor_uri)
                            old_descriptor_path = os.path.join(asset_dir, old_descriptor_uri)
                            if os.path.exists(old_descriptor_path):
                                os.remove(old_descriptor_path)
                package_json["nodePackage"]["assets"] = assets
                if new_asset and not interface_exists(package_json["nodePackage"]["interfaces"], args.model_asset_name):
                    package_json["nodePackage"]["interfaces"].append(interface_template)
                    default_node_name = args.model_asset_name
                    account_id = get_aws_account_id()
                    full_interface_path = account_id + "::" + get_package_name(package_path) + "." + interface_template["name"]
                    add_graph_node(default_node_name, full_interface_path)
                f.seek(0)
                json.dump(package_json, f, indent=4)
                f.truncate()
    else:
        print("Copy the following in the assets section of package.json")
    print(json.dumps(asset_template, indent=4))
    print('Successfully downloaded the model to ' + final_model_dst_path)

@verify_cwd_is_project_root
def download_compiled_model():
    """Download a compiled model, package it as a filesystem image, and print its asset entry.

    Source of the model:
    * ``args.model_s3_uri`` is given: downloaded directly from that URI.
    * Otherwise: the job recorded in var.json is looked up. If it is still
      running, or ended without completing, a status message is printed and the
      function returns. If it completed, the last entry listed under its output
      prefix is downloaded, and var.json is deleted only after the download
      succeeds, so a failed download can be retried.

    The archive goes to ``./assets/<name>/<name>.tar.gz``. It is turned into an
    ext4 (``args.ext4fs``) or squashfs image named after the SHA-256 of the
    asset name, and that image is moved into ./assets.
    """
    file_base_name = args.model_asset_name
    asset_dir = "./assets/" + file_base_name + "/"
    model_dir = asset_dir + file_base_name + '.tar.gz'
    from_compilation_job = not args.model_s3_uri
    if from_compilation_job:
        job_id, s3_uri = get_compilation_details()
        print('Getting response for job_id ' + job_id)
        resp = get_compilation_job_response(job_id)
        # Neo API doc: https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_DescribeCompilationJob.html
        status = resp[COMPILATION_JOB_STATUS]
        if status == 'STARTING' or status == 'INPROGRESS':
            print('Waiting for compilation job ' + job_id + ' to be completed. This may take a few minutes. '
                                                            'Please call again later.')
            print('Current compilation job status: ' + status)
            return
        elif status != 'COMPLETED':
            print('Compilation job terminated with status: ' + status)
            # FailureReason is not guaranteed to be present for every terminal status.
            print('Failure reason: ' + str(resp.get(FAILURE_REASON, 'not reported')))
            return
        print('Compilation job completed. Downloading model artifacts...')
        cmd = 'aws s3 ls ' + s3_uri
        return_code, out = execute(cmd)
        if return_code != 0:
            throw_error(return_code, out, 'Error when fetching compiled model in S3.')
        # The last whitespace-separated token of the listing is taken as the artifact name.
        listing = out.split()
        if not listing:
            sys.exit('No compiled model artifacts found under ' + s3_uri)
        s3_uri = os.path.join(s3_uri, listing[-1])
    else:
        s3_uri = args.model_s3_uri
    cmd = 'aws s3 cp ' + s3_uri + ' ' + model_dir

    # The descriptor is copied into asset_dir below, so make sure it exists.
    os.makedirs(asset_dir, exist_ok=True)

    start = time.time()
    return_code, out = execute(cmd, runtime_print_output=True)
    print("Time taken to download ", time.time() - start)
    if return_code != 0:
        throw_error(return_code, out, 'Error when downloading compiled artifacts (' + s3_uri + ') to ' + model_dir)
    print('Successfully downloaded compiled artifacts (' + s3_uri + ') to ' + model_dir)
    if from_compilation_job:
        # Forget the job only now that its artifacts are on disk.
        os.remove(VAR_PATH)

    descriptor_src_path = os.path.join(CLI_DIR, DESCRIPTOR_MODEL_TEMPLATE_PATH)
    descriptor_dst_path = os.path.join(asset_dir, "descriptor.json")
    shutil.copyfile(descriptor_src_path, descriptor_dst_path)

    start = time.time()

    if args.ext4fs:
        model_img_name = get_hash(file_base_name) + ".img"
        create_ext4_fs_image(model_dir, model_img_name)
    else:
        model_img_name = get_hash(file_base_name) + ".sqfs"
        create_squash_fs_image(model_dir, model_img_name)
    print("Time taken to create fs image", time.time() - start)
    image_src_path = os.path.join(os.getcwd(), model_img_name)
    image_dst_path = os.path.join(os.getcwd(), "assets", model_img_name)
    shutil.move(image_src_path, image_dst_path)

    assest_blob_path = os.path.join(CLI_DIR, DOWNLOAD_ASSET_TEMPLATE_PATH)
    with open(assest_blob_path) as f:
        asset_template = json.load(f)
        # NOTE(review): other commands use args.model_asset_name; confirm that this
        # command's parser actually defines args.model_name.
        asset_template["name"] = args.model_name
        asset_template["implementations"][0]["assetUri"] = model_img_name
        print("Copy the following in the assets section of package.json")
        print(json.dumps(asset_template, indent=4))

def create_template_file(path, file_name, template_path, is_json=True):
    """Write the bundled resource *template_path* to ``<path>/<file_name>``.

    JSON templates are parsed and re-written with 4-space indentation. Other
    templates (``is_json=False``) are copied byte-for-byte. A JSON template is
    parsed before the target file is opened, so a parse error leaves no file
    behind.
    """
    file_path = os.path.join(path, file_name)
    raw_template = pkg_resources.resource_string(pypi_package_name, template_path)
    if not is_json:
        with open(file_path, "wb") as out_file:
            out_file.write(raw_template)
        return
    parsed_template = json.loads(raw_template)
    with open(file_path, "w") as out_file:
        json.dump(parsed_template, out_file, indent=4)

@verify_cwd_is_project_root
def create_package():
    """Create ``packages/<account>-<name>-<version>`` from templates and wire it into the graph.

    Steps:
    1. Validate ``args.type``; an invalid type exits via throw_error.
    2. If the package directory already exists, print a message and stop.
    3. Write the templates for the package type. Camera and Datasink packages
       also get a default graph node.
    4. Set the package name and description in package.json.
    5. List the package in the graph's ``packages``.
    6. Make a best-effort ``aws panorama create-package`` call if
       describe-package does not find it; a failure is ignored so the app can
       still be developed offline.
    """
    valid_types = ["Container", "Model", "Camera", "Datasink"]
    if args.type not in valid_types:
        throw_error(1, None, "Invalid package type. Provided type " + args.type + " not in " + ' '.join(valid_types))

    account_id = get_aws_account_id()
    version = args.version
    name = args.name
    pkg_path = os.path.join(os.getcwd(), "packages", account_id + "-" + name + "-" + version)
    if os.path.exists(pkg_path):
        print("Package " + name + " already exists")
        return
    os.mkdir(pkg_path)

    if args.camera or args.type == "Camera":
        create_template_file(pkg_path, 'package.json', CAMERA_PACAKAGE_TEMPLATE_PATH)
        add_graph_node(name, account_id + "::" + name + ".rtsp_interface")
    elif args.type == "Datasink":
        create_template_file(pkg_path, 'package.json', DATASINK_PACAKAGE_TEMPLATE_PATH)
        add_graph_node(name, account_id + "::" + name + ".sink_interface")
    elif args.model or args.type == "Model":
        create_template_file(pkg_path, 'package.json', PACKAGE_TEMPLATE_PATH)
        create_template_file(pkg_path, 'descriptor.json', DESCRIPTOR_MODEL_TEMPLATE_PATH)
    else:
        # Container package: source directory plus descriptor, manifest and Dockerfile.
        os.mkdir(os.path.join(pkg_path, "src"))
        create_template_file(pkg_path, 'descriptor.json', DESCRIPTOR_CONTAINER_TEMPLATE_PATH)
        create_template_file(pkg_path, 'package.json', PACKAGE_TEMPLATE_PATH)
        create_template_file(pkg_path, 'Dockerfile', DOCKERFILE_PATH, is_json=False)

    manifest_path = os.path.join(pkg_path, "package.json")
    with open(manifest_path, "r+") as handle:
        manifest = json.load(handle)
        node_package = manifest["nodePackage"]
        node_package["name"] = name
        node_package["description"] = "Default description for package " + name
        handle.seek(0)
        json.dump(manifest, handle, indent=4)
        handle.truncate()

    graphs_root = os.path.join(os.getcwd(), "graphs")
    graph_dirs = [entry for entry in os.listdir(graphs_root) if os.path.isdir(os.path.join(graphs_root, entry))]
    graph_path = os.path.join(graphs_root, graph_dirs[0], "graph.json")
    with open(graph_path, "r+") as handle:
        graph = json.load(handle)
        graph["nodeGraph"]["packages"].append({
            "name": account_id + "::" + name,
            "version": version
        })
        handle.seek(0)
        json.dump(graph, handle, indent=4)
        handle.truncate()

    # Register with the service only when describe-package cannot find the package.
    # A failed create is deliberately ignored: package-application handles it later,
    # which allows offline app development.
    describe_status, _ = execute(['aws panorama describe-package --output json --package-id packageName/' + name])
    if describe_status != 0:
        execute(['aws panorama create-package --output json --package-name ' + name])

    print('Successfully created package ' + name)

@verify_cwd_is_project_root
def add_abstract_camera():
    """Add an abstract RTSP camera node named ``args.name`` to the application graph.

    The ``panorama::abstract_rtsp_media_source`` 1.0 package is listed in the
    graph's packages if it is missing. If a node with the same name already
    exists, a message is printed and graph.json is left untouched.
    """
    graphs_root = os.path.join(os.getcwd(), "graphs")
    graph_dirs = [entry for entry in os.listdir(graphs_root) if os.path.isdir(os.path.join(graphs_root, entry))]
    graph_path = os.path.join(graphs_root, graph_dirs[0], "graph.json")
    package_ref = {"name": "panorama::abstract_rtsp_media_source", "version": "1.0"}
    with open(graph_path, "r+") as handle:
        graph = json.load(handle)
        node_graph = graph["nodeGraph"]
        # Scan every package entry (no short-circuit) before inspecting nodes.
        matching_packages = [pkg for pkg in node_graph["packages"]
                             if pkg["name"] == package_ref["name"] and pkg["version"] == package_ref["version"]]
        if not matching_packages:
            node_graph["packages"].append(package_ref)
        if any(node["name"] == args.name for node in node_graph["nodes"]):
            print("Package " + args.name + " already exists")
            return
        node_graph["nodes"].append({
            "name": args.name,
            "interface": "panorama::abstract_rtsp_media_source.rtsp_v1_interface",
            "overridable": True,
            "launch": "onAppStart",
            "decorator": {
                "title": "Camera " + args.name,
                "description": "Default description for camera " + args.name
            }
        })
        handle.seek(0)
        json.dump(graph, handle, indent=4)
        handle.truncate()

@verify_cwd_is_project_root
def add_data_sink():
    """Add an HDMI data-sink node named ``args.name`` to the application graph.

    The ``panorama::hdmi_data_sink`` 1.0 package is listed in the graph's
    packages if it is missing. If a node with the same name already exists, a
    message is printed and graph.json is left untouched.
    """
    graphs_root = os.path.join(os.getcwd(), "graphs")
    graph_dirs = [entry for entry in os.listdir(graphs_root) if os.path.isdir(os.path.join(graphs_root, entry))]
    graph_path = os.path.join(graphs_root, graph_dirs[0], "graph.json")
    package_ref = {"name": "panorama::hdmi_data_sink", "version": "1.0"}
    with open(graph_path, "r+") as handle:
        graph = json.load(handle)
        node_graph = graph["nodeGraph"]
        # Scan every package entry (no short-circuit) before inspecting nodes.
        matching_packages = [pkg for pkg in node_graph["packages"]
                             if pkg["name"] == package_ref["name"] and pkg["version"] == package_ref["version"]]
        if not matching_packages:
            node_graph["packages"].append(package_ref)
        if any(node["name"] == args.name for node in node_graph["nodes"]):
            print("Package " + args.name + " already exists")
            return
        node_graph["nodes"].append({
            "name": args.name,
            "interface": "panorama::hdmi_data_sink.hdmi0",
            "overridable": False,
            "launch": "onAppStart"
        })
        handle.seek(0)
        json.dump(graph, handle, indent=4)
        handle.truncate()

def add_panorama_package():
    """Dispatch ``add-panorama-package`` to the handler matching ``args.type``.

    ``camera`` adds an abstract camera node, and ``data_sink`` adds the HDMI data sink.
    Any other value only prints a hint listing the accepted types.
    """
    handlers = {
        "camera": add_abstract_camera,
        "data_sink": add_data_sink,
    }
    handler = handlers.get(args.type)
    if handler is None:
        print("Please enter a valid type, camera or data_sink")
        return
    handler()

def create_fs():
    """Build a filesystem image from ``args.file_path`` and print how long it took.

    The image is named after the input's base name up to its first ``.``. It is
    ``<base>.img`` (via ``create_ext4_fs_image``) when ``args.ext4fs`` is set, and
    ``<base>.sqfs`` (via ``create_squash_fs_image``) otherwise.
    """
    started_at = time.time()
    base_name = os.path.basename(args.file_path).split('.')[0]
    if args.ext4fs:
        image_name = base_name + ".img"
        builder = create_ext4_fs_image
    else:
        image_name = base_name + ".sqfs"
        builder = create_squash_fs_image
    builder(args.file_path, image_name)
    print("Time taken to create fs image", time.time() - started_at)

def init_project():
    """Create the directory skeleton for a new project named ``args.name``.

    The layout created in the current directory is ``<name>/assets``,
    ``<name>/graphs`` and ``<name>/packages``, plus ``<name>/graphs/<name>/graph.json``
    rendered from GRAPH_TEMPLATE_PATH. If ``<name>`` already exists, only a message
    is printed.
    """
    root = os.path.join(os.getcwd(), args.name)
    if os.path.exists(root):
        print("Application " + args.name + " already exists")
        return
    os.mkdir(root)
    for sub_dir in ('assets', 'graphs', 'packages'):
        os.mkdir(os.path.join(root, sub_dir))

    graph_dir = os.path.join(root, "graphs", args.name)
    os.mkdir(graph_dir)
    create_template_file(graph_dir, 'graph.json', GRAPH_TEMPLATE_PATH)
    print('Successfully created the project skeleton at ' + root)

@verify_cwd_is_project_root
def package_application():
    """Upload each package's assets and manifest to Panorama and register it.

    Packages are taken from ``packages/``. These are all sub-directories, or only
    the ones named in ``args.packages_path``, which are matched by base name under
    ``packages/``. Directories whose name has fewer than three ``-``-separated
    parts are skipped. For each remaining package this:

    1. Describes the package on Panorama and creates it if describe-package fails.
    2. Uses the SHA-256 hex digest of ``package.json`` as the patch version and
       ``nodePackage.version`` from that manifest (default ``"1.0"``) as the
       package version.
    3. Skips uploading when that patch version is already REGISTER_COMPLETED or
       REGISTER_PENDING. It is still awaited below.
    4. Uploads each non-``system`` asset from ``assets/``, plus its
       ``descriptorUri`` if present, to the package's S3 binary prefix, unless
       the key already exists. If any upload fails, the package is not
       registered and is reported at the end.
    5. Uploads ``package.json`` to the manifest prefix (if missing) and calls
       register-package-version with ``--mark-latest``.

    Afterwards it polls describe-package-version every 30 seconds until each
    registered package is REGISTER_COMPLETED. It exits the process (``sys.exit``)
    as soon as one reports any status other than REGISTER_PENDING or
    REGISTER_COMPLETED, e.g. FAILED.
    """
    graphs_path = os.path.join(os.getcwd(), "graphs")
    graph_dir = [a for a in os.listdir(graphs_path) if os.path.isdir(os.path.join(graphs_path, a))][0]
    graph_json_path = os.path.join(graphs_path, graph_dir, "graph.json")
    # The graph content is not needed below. Loading it only fails fast when the
    # project lacks a parseable graph.json. The file is opened read-only because
    # it is never written here.
    with open(graph_json_path, "r") as f:
        json.load(f)
    assets_dir = os.path.join(os.getcwd(), 'assets')
    packages_dir = os.path.join(os.getcwd(), 'packages')
    if args.packages_path:
        # Only the base name of each given path is used; it is resolved under packages/.
        package_list = [os.path.basename(os.path.normpath(p)) for p in args.packages_path]
    else:
        package_list = [a for a in os.listdir(packages_dir) if os.path.isdir(os.path.join(packages_dir, a))]
    # package_id -> {"package-version", "patch-version", "package-name"} for every
    # package whose registration is awaited below.
    registered_packages = {}
    # package_name -> list of asset/descriptor URIs that could not be uploaded.
    failed_packages = {}
    for package in package_list:
        package_path = os.path.join(packages_dir, package)
        if len(package.split('-')) < 3:
            print(package_path + " is not a package, ignoring")
            continue
        package_name = get_package_name(package_path)
        print("Uploading package " + package_name)

        # Check if the package already exists; create it when describe fails.
        status, output = execute(['aws panorama describe-package --output json --package-id packageName/' + package_name])
        if status != 0:
            status, output = execute(['aws panorama create-package --output json --package-name ' + package_name])
            if status != 0:
                throw_error(status, output, 'Error creating package ' + package + ' on Panorama')
        # Whichever call produced `output`, it is expected to carry PackageId and StorageLocation.
        package_info = json.loads(output)
        package_id = package_info["PackageId"]
        storage_location = package_info["StorageLocation"]
        package_json_path = os.path.join(package_path, "package.json")
        with open(package_json_path, "r") as f:
            package_json = json.load(f)
        # The patch version is the hash of the manifest's raw bytes, so an
        # unchanged package.json maps to the same patch version on every run.
        with open(package_json_path, "rb") as f:
            package_json_hash = hashlib.sha256(f.read()).hexdigest()
        # "1.0" is the default package version when the manifest declares none.
        package_version = package_json["nodePackage"].get("version", "1.0")

        # Ignore upload if the same patch version of a package is already registered.
        status, output = execute(['aws panorama describe-package-version --output json --package-id ' + package_id + ' --package-version ' + package_version])
        if status != 0:
            print("Package Version " + package_version + " is not yet registered, preparing upload")
        else:
            package_version_info = json.loads(output)
            patch_version = package_version_info["PatchVersion"]
            if patch_version == package_json_hash and package_version_info["Status"] in ["REGISTER_COMPLETED", "REGISTER_PENDING"]:
                print("Patch Version " + patch_version + " already registered, ignoring upload")
                registered_packages[package_id] = {"package-version": package_version, "patch-version": patch_version, "package-name": package_name}
                continue
            else:
                print("Patch version for the package " + package_json_hash)

        assets_not_found = []
        for asset in package_json["nodePackage"]["assets"]:
            for implementation in asset["implementations"]:
                if implementation["type"] == "system":  # System assets are not uploaded.
                    continue
                asset_uri = implementation["assetUri"]
                asset_path = os.path.join(assets_dir, asset_uri)
                status, output = execute(["aws s3api head-object --bucket " + storage_location["Bucket"] + " --key " + storage_location["BinaryPrefixLocation"] + "/" + asset_uri])
                if status != 0:
                    status, output = execute(["aws s3 cp " + get_path_with_espace_characters(asset_path) + " s3://" + storage_location["Bucket"] + "/" + storage_location["BinaryPrefixLocation"] + "/" + asset_uri + " --acl bucket-owner-full-control"], runtime_print_output=True)
                    if status != 0:
                        throw_error(status, output, 'Error uploading package asset ' + asset_uri + ' to S3', exit=False)
                        assets_not_found.append(asset_uri)
                else:
                    print("Asset " + asset_uri + " already exists, ignoring upload")
                if "descriptorUri" in implementation:
                    descriptor_uri = implementation["descriptorUri"]
                    descriptor_path = os.path.join(assets_dir, descriptor_uri)
                    status, output = execute(["aws s3api head-object --bucket " + storage_location["Bucket"] + " --key " + storage_location["BinaryPrefixLocation"] + "/" + descriptor_uri])
                    if status != 0:
                        status, output = execute(["aws s3 cp " + get_path_with_espace_characters(descriptor_path) + " s3://" + storage_location["Bucket"] + "/" + storage_location["BinaryPrefixLocation"] + "/" + descriptor_uri + " --acl bucket-owner-full-control"], runtime_print_output=True)
                        if status != 0:
                            throw_error(status, output, 'Error uploading package descriptor ' + descriptor_uri + ' to S3', exit=False)
                            assets_not_found.append(descriptor_uri)
                    else:
                        print("Descriptor " + descriptor_uri + " already exists, ignoring upload")
        if assets_not_found:
            print("Assets not found for package " + package_name + ". Skipping registration")
            failed_packages[package_name] = assets_not_found
            continue
        status, output = execute(["aws s3api head-object --bucket " + storage_location["Bucket"] + " --key " + storage_location["ManifestPrefixLocation"] + "/" + package_version + "/" + package_json_hash + ".json"])
        if status != 0:
            status, output = execute(["aws s3api put-object --bucket " + storage_location["Bucket"] + " --key " + storage_location["ManifestPrefixLocation"] + "/" + package_version + "/" + package_json_hash + ".json --body " + get_path_with_espace_characters(package_json_path) + " --acl bucket-owner-full-control"], runtime_print_output=True)
            if status != 0:
                throw_error(status, output, 'Error uploading package json to S3')
        status, output = execute(["aws panorama register-package-version --output json --package-id " + package_id + " --package-version " + package_version + " --patch-version " + package_json_hash + " --mark-latest"], runtime_print_output=True)
        if status != 0:
            throw_error(status, output, 'Error registering package json with Panorama')
        print("Called register package version for " + package_name + " with patch version " + package_json_hash)
        registered_packages[package_id] = {"package-version": package_version, "patch-version": package_json_hash, "package-name": package_name}

    # Wait for register-package-version to finish on the cloud side.
    for package_id, registration in registered_packages.items():
        name = registration["package-name"]
        version = registration["package-version"]
        patch = registration["patch-version"]
        while True:
            status, output = execute(["aws panorama describe-package-version --output json --package-id " + package_id + " --package-version " + version + " --patch-version " + patch])
            if status != 0:
                throw_error(status, output, "Error while checking register package version status for " + name + " with patch version " + patch)
            else:
                package_version_info = json.loads(output)
                version_status = package_version_info["Status"]
                if version_status == "REGISTER_COMPLETED":
                    print("Register patch version complete for " + name + " with patch version " + patch)
                    break
                if version_status != "REGISTER_PENDING":
                    # FAILED, DELETING or any status this CLI does not know can
                    # never become REGISTER_COMPLETED. Stop here instead of
                    # polling forever.
                    description = package_version_info.get("StatusDescription")
                    if description is not None:
                        print(description)
                    if version_status != "FAILED":
                        print("Unexpected package version status " + str(version_status))
                    print("Register patch version failed for " + name + " with patch version " + patch)
                    sys.exit("Fix the error and run panorama-cli package-application again")
                print("Waiting for register package version to finish for " + name)
            time.sleep(30)
    if failed_packages:
        for failed_package in failed_packages:
            print("Assets ", failed_packages[failed_package], " not found for package " + failed_package)
        print("Fix the errors and run panorama-cli package-application again to re-try upload for failed packages")
    else:
        print("All packages uploaded and registered successfully")
    

@verify_cwd_is_project_root
def import_application():
    """Move an application created under another AWS account onto the current account.

    The current account id comes from ``get_aws_account_id()``.

    - Package directories under ``packages/`` with at least three ``-``-separated
      parts have their leading segment (the old account id) replaced by it.
    - In ``graph.json`` (first directory under ``graphs/``), every non-``panorama``
      package name of the form ``<account_id>::<rest>`` gets the new account id.
    - Every node interface whose ``<owner>::`` prefix was one of those old
      account ids is rewritten the same way. Other interfaces are left untouched.
    """
    account_id = get_aws_account_id()
    packages_dir = os.path.join(os.getcwd(), 'packages')
    package_list = [a for a in os.listdir(packages_dir) if os.path.isdir(os.path.join(packages_dir, a))]
    for package in package_list:
        if len(package.split('-')) < 3:
            continue
        old_account_id = package.split('-')[0]
        # Swap only the leading account segment. str.replace() would also rewrite
        # any other occurrence of the old id inside the directory name.
        new_package = account_id + package[len(old_account_id):]
        if new_package != package:
            shutil.move(os.path.join(packages_dir, package), os.path.join(packages_dir, new_package))

    graphs_path = os.path.join(os.getcwd(), "graphs")
    graph_dir = [a for a in os.listdir(graphs_path) if os.path.isdir(os.path.join(graphs_path, a))][0]
    graph_json_path = os.path.join(graphs_path, graph_dir, "graph.json")
    with open(graph_json_path, "r") as f:
        graph_json = json.load(f)
    # Account ids that prefixed ("<account_id>::") the non-Panorama packages.
    old_account_ids = set()
    for package in graph_json["nodeGraph"]["packages"]:
        owner, separator, rest = package["name"].partition("::")
        # Panorama-provided packages keep their names. Names without a non-empty
        # "<owner>::" prefix have no account id to rewrite.
        if owner == "panorama" or not owner or not separator:
            continue
        old_account_ids.add(owner)
        package["name"] = account_id + separator + rest
    # Match the owner prefix exactly. A substring replace with an empty old id
    # ("x".replace("", id)) would insert the id between every character.
    for node in graph_json["nodeGraph"]["nodes"]:
        owner, separator, rest = node["interface"].partition("::")
        if separator and owner in old_account_ids:
            node["interface"] = account_id + separator + rest
    with open(graph_json_path, "w") as f:
        json.dump(graph_json, f, indent=4)
    print("Sucessfully imported application")
    
def main(passed_args=None):
    """Parse command-line arguments and run the selected panorama-cli subcommand.

    Args:
        passed_args: Optional list of arguments to parse instead of ``sys.argv[1:]``.
            Useful for invoking the CLI programmatically.

    The parsed namespace is stored in the module-level global ``args``, which the
    subcommand handlers read. Each subparser registers its handler through
    ``set_defaults(func=...)``. The help text is printed and the process exits
    with status 0 in two cases: when the program is run with no arguments at all,
    and when parsing yields no subcommand (e.g. ``main([])``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--version", action='version', version='%(prog)s ' + __version__, help="Version of panorama-cli")
    subparsers = parser.add_subparsers()

    init_project_parser = subparsers.add_parser("init-project", help="Initializes and creates the directory structure for the project")
    init_project_parser.add_argument("--name", required=True, help="Name of the project")
    init_project_parser.set_defaults(func=init_project)

    add_model_parser = subparsers.add_parser("add-raw-model", help="Add raw model artifacts")
    add_model_parser.add_argument("--model-asset-name", required=True, help="Name for model being downloaded")
    add_model_group = add_model_parser.add_mutually_exclusive_group(required=True)
    add_model_group.add_argument("--model-s3-uri", help="S3 URI of the raw model")
    add_model_group.add_argument("--model-local-path", help="Local path of the raw model")
    add_model_parser.add_argument("--descriptor-path", required=True, help="Path for the descriptor json for the model")
    add_model_parser.add_argument("--packages-path", nargs="+", type=str, default=[], help="List of packages where this model will be used. Downloaded model asset will be directly defined in those packages if provided")
    add_model_parser.set_defaults(func=add_raw_model)

    create_package_parser = subparsers.add_parser("create-package", help="Create a new package")
    create_package_parser.add_argument("--name", required=True, help="Name of the package")
    create_package_parser.add_argument('--type', default="Container", help="Type of the package. Model or Camera or Datasink or Container(Default)")
    create_package_parser.add_argument("--version", default="1.0", help="Version of the package, 1.0 by default")
    create_package_parser.add_argument('-camera', action='store_true')
    create_package_parser.add_argument('-model', action='store_true')
    create_package_parser.set_defaults(func=create_package)

    add_panorama_package_parser = subparsers.add_parser("add-panorama-package", help="Add a package provided by Panorama")
    add_panorama_package_parser.add_argument("--type", required=True, help="Type of the package, camera or data_sink")
    add_panorama_package_parser.add_argument("--name", required=True, help="Name of the package node")
    add_panorama_package_parser.set_defaults(func=add_panorama_package)

    # Deprecated, need to remove. Replaced with build-container command.
    build_parser = subparsers.add_parser("build", help="(Deprecated) Use build-container instead. Build the package.")
    build_parser.add_argument("--container-asset-name",  required=True, help="Name of the package")
    build_parser.add_argument("--package-path",  required=True, help="Path for the package to be built")
    build_parser.set_defaults(func=build)

    build_container_parser = subparsers.add_parser("build-container", help="Build the package")
    build_container_parser.add_argument("--container-asset-name",  required=True, help="Name of the package")
    build_container_parser.add_argument("--package-path",  required=True, help="Path for the package to be built")
    build_container_parser.set_defaults(func=build)

    export_parser = subparsers.add_parser("export-container", help="Export a pre-built docker container to your package")
    export_parser.add_argument("--container-asset-name",  required=True, help="Name of the package")
    export_parser.add_argument("--container-image-uri",  required=True, help="Uri of the Docker Image")
    export_parser.add_argument("--package-path",  required=True, help="Path for the package to be built")
    export_parser.set_defaults(func=export)

    package_application_parser = subparsers.add_parser("package-application", help="Uploads all the application assets to panorama cloud account along with all the manifests")
    package_application_parser.add_argument("--packages-path", nargs="+", type=str, default=[], help="Uploads only these packages if provided")
    package_application_parser.set_defaults(func=package_application)

    import_application_parser = subparsers.add_parser("import-application", help="Import an application created by someone else. Updates account id at all relevant places")
    import_application_parser.set_defaults(func=import_application)

    if passed_args is None and len(sys.argv) == 1:
        # print_help() writes the help text itself and returns None. Wrapping
        # it in print() also emitted a stray "None" line.
        parser.print_help()
        sys.exit(0)

    global args
    args = parser.parse_args(passed_args)  # Picks up sys.argv directly if passed_args is None
    if not hasattr(args, "func"):
        # No subcommand was selected (e.g. an empty passed_args list), so there
        # is no handler to call. Show usage instead of raising AttributeError.
        parser.print_help()
        sys.exit(0)
    args.func()

if __name__ == "__main__":
    # Script entry point. Replace json.load with the version wrapped by
    # json_load_decorator (defined elsewhere in this file) before dispatching.
    _wrapped_json_load = json_load_decorator(json.load)
    json.load = _wrapped_json_load
    main()