#!python
#
# Licensed under the Apache License, Version 2.0 (the "License");
# You may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Authors:
# - Wen Guan, <wen.guan@cern.ch>, 2024 - 2025


"""
Run workflow.
"""

from __future__ import print_function

import argparse
import base64
import logging
import json
import os
import sys
import time
import traceback

from idds.common.utils import json_loads, decode_base64 as idds_decode_base64


# Timestamps of all log records are rendered in UTC instead of local time.
logging.Formatter.converter = time.gmtime

# Log everything from DEBUG upwards to stderr, one tab separated record per line.
logging.basicConfig(
    format='%(asctime)s\t%(threadName)s\t%(name)s\t%(levelname)s\t%(message)s',
    level=logging.DEBUG,
    stream=sys.stderr,
)


def get_parser(program):
    """
    Build the command line parser of this wrapper.

    :param program: path of the running program; only its base name is used
                    as the ``prog`` shown by the parser.
    :returns: an ``argparse.ArgumentParser`` with the ``--pre_setup``,
              ``--setup`` and ``--post_script`` options and a ``run_args``
              positional that collects all remaining arguments.
    """
    parser = argparse.ArgumentParser(add_help=True, prog=os.path.basename(program))

    # common items, registered in the order they show up in the help output
    common_options = (
        ('pre_setup', "The pre_setup."),
        ('setup', "The setup."),
        ('post_script', "Post script to run after workflow (bash code)"),
    )
    for option, description in common_options:
        parser.add_argument('--' + option, dest=option, help=description)

    # The output_map, inputs and input_map options are currently disabled.

    # Everything left on the command line is collected into run_args.
    parser.add_argument('run_args', nargs=argparse.REMAINDER, help="All other arguments")
    return parser


def get_workflow_setup_script(extra_env=None):
    """
    Build the bash preamble that prepares the environment of the workflow.

    The preamble prints the working directory and the python locations,
    prepends the working directory (plus its ``tmp_bin`` and ``bin``) to
    ``PATH`` and the working directory plus ``lib_py`` to ``PYTHONPATH``,
    links ``python3`` to ``./python`` when no ``python`` command is found,
    exports ``X509_USER_PROXY`` when an ``x509_proxy`` file exists and finally
    dumps the environment with ``env``.

    :param extra_env: optional mapping of environment variable names to
                      values; each entry is exported by the script. Values
                      are converted with ``str()`` and single quoted.
    :returns: the script text as a string.
    """

    script = """#!/bin/bash

echo "current dir: " $PWD

which python
which python3

# cd ${current_dir}

# if it's in a container, this part is needed again to setup the environment.
current_dir=$PWD
export PATH=${current_dir}:${current_dir}/tmp_bin:${current_dir}/bin:$PATH
export PYTHONPATH=${current_dir}:${current_dir}/lib_py:$PYTHONPATH

if ! command -v python &> /dev/null
then
    echo "no python, link python3 to python"
    # alias python=python3
    ln -fs $(which python3) ./python
fi

if [ -f ${current_dir}/x509_proxy ]; then
    export X509_USER_PROXY=${current_dir}/x509_proxy
fi

"""
    if extra_env:
        script += "\n# Set extra environment variables from context\n"
        script += "echo 'Setup PanDA environments'\n"
        exports = []
        for k, v in extra_env.items():
            # A single quote inside a single quoted bash string would end the
            # string early; close the quoting, emit an escaped quote and
            # reopen it ('\'') so that the value is always taken literally.
            quoted = str(v).replace("'", "'\\''")
            exports.append(f"export {k}='{quoted}'\n")
        script += "".join(exports)

    script += "env\n"
    return script


# This part may have to run on SL7 with python2.
def decode_base64(sb, remove_quotes=False):
    """
    Decode a base64 encoded value into UTF-8 text.

    :param sb: the base64 data as ``str`` or ``bytes``; a value of any other
               type is returned unchanged.
    :param remove_quotes: when true, the first and the last character of the
                          decoded text (the surrounding single quotes) are
                          dropped.
    :returns: the decoded text. When decoding fails the error is logged and
              ``sb`` itself is returned.
    """
    try:
        if isinstance(sb, bytes):
            # This also covers a python 2 ``str``, which is already bytes.
            encoded = sb
        elif isinstance(sb, str):
            encoded = sb.encode('ascii')
        else:
            return sb

        text = base64.b64decode(encoded).decode("utf-8")
        # remove the single quotes after decoding
        return text[1:-1] if remove_quotes else text
    except Exception as ex:
        logging.error("decode_base64 %s: %s" % (sb, ex))
        return sb


def create_run_workflow_cmd(run_args, extra_env=None, post_script=None):
    """
    Write the ``./run_workflow.sh`` script and return its path.

    The script consists of the setup preamble returned by
    ``get_workflow_setup_script(extra_env)``, the workflow command built from
    ``run_args``, the optional post script and a final ``exit`` with the
    return code of the workflow command. The file is created in the current
    working directory with mode ``0o755``.

    :param run_args: the workflow command and its arguments; each item is
                     converted with ``str()`` and wrapped in double quotes.
    :param extra_env: optional mapping of environment variables forwarded to
                      the setup preamble.
    :param post_script: optional bash code appended after the workflow
                        command.
    :returns: the path of the generated script (``./run_workflow.sh``).
    """
    # current_dir = os.getcwd()

    run_script = "./run_workflow.sh"
    script = get_workflow_setup_script(extra_env) + "\n"
    script += " ".join(f'"{str(arg)}"' for arg in run_args)
    # The command has to end its own line: without the newline "ret=$?" is
    # glued to the last argument, which corrupts that argument and leaves
    # $ret unset for the final "exit $ret".
    script += "\nret=$?\n"

    # Append post_script directly
    if post_script:
        script += "\n\n# Post script section\n"
        script += post_script

    script += "\n\n# Exit with the return code of the workflow command\n"
    script += "exit $ret\n"

    logging.debug("script: ")
    logging.debug(script)

    with open(run_script, 'w') as f:
        f.write(script)
    os.chmod(run_script, 0o755)

    return run_script


def extract_context_from_run_args(run_args):
    """
    Load the workflow context referenced by ``run_args`` and prepare it.

    The context is taken from the value following ``--context`` (base64
    encoded, parsed with ``json_loads``). When ``--args_file`` names an
    existing file whose decoded content has a non-empty ``context`` entry,
    that entry replaces the command line context; problems while reading the
    file are only logged. A loaded context is initialized, its source files
    are set up and its PanDA/iDDS environment and post script are collected;
    failures in this step are logged as warnings.

    :param run_args: sequence of command line arguments.
    :returns: tuple ``(context, extra_env, post_script)``; every item is
              ``None`` when it could not be obtained.
    """
    context = None
    args_file = None

    # Look at every argument together with its successor; a flag in the last
    # position has no value and is ignored. Later occurrences win.
    for flag, value in zip(run_args, run_args[1:]):
        if flag == '--context':
            decoded = idds_decode_base64(value)
            context = json_loads(decoded)
        elif flag == '--args_file':
            args_file = value

    if args_file and os.path.exists(args_file):
        try:
            with open(args_file, 'r') as handle:
                raw_content = handle.read()
            payload = json_loads(idds_decode_base64(raw_content))
            if 'context' in payload:
                encoded_context = payload['context']
                if encoded_context:
                    # the context from the file overrides the command line one
                    context = json_loads(idds_decode_base64(encoded_context))
        except Exception as ex:
            logging.warning("Failed to load context from args_file: %s" % ex)

    extra_env = None
    post_script = None
    if context is None:
        return context, extra_env, post_script

    try:
        logging.info("Initializing context and setting up source files...")
        context.initialize()
        context.setup_source_files()

        logging.info("loading panda idds envs...")
        extra_env = context.get_panda_idds_env()
        logging.info(f"Panda idds envs loaded: {extra_env}")

        post_script = context.post_script
        logging.info(f"Post script from context: {post_script}")
    except Exception as ex:
        logging.warning("Failed to setup source files: %s" % ex)

    return context, extra_env, post_script


def process_args(args):
    """
    Turn the parsed command line into the shell command that runs the workflow.

    ``args.pre_setup``, ``args.setup`` and ``args.post_script`` are decoded
    with ``decode_base64`` and parsed as JSON. The context found in
    ``args.run_args`` is initialized, and its post script is appended to the
    one given on the command line. The workflow itself is written to a run
    script by ``create_run_workflow_cmd``.

    :param args: namespace with the attributes ``pre_setup``, ``setup``,
                 ``post_script`` and ``run_args``.
    :returns: the command string: the pre_setup, the setup and the path of
              the generated run script.
    """
    # output_map and inputs are currently not part of the arguments.
    for label, attribute in (("pre_setup:", "pre_setup"),
                             ("setup: ", "setup"),
                             ("post_script: ", "post_script"),
                             ("run_args:", "run_args")):
        logging.debug(label)
        logging.debug(getattr(args, attribute))

    def _load(encoded):
        # base64 text -> JSON document -> python value
        return json.loads(decode_base64(encoded, remove_quotes=False))

    cmd = ""
    pre_setup = _load(args.pre_setup) if args.pre_setup else None
    if pre_setup:
        cmd += pre_setup
    setup = _load(args.setup) if args.setup else None
    if setup:
        cmd += " " + setup
    post_script = _load(args.post_script) if args.post_script else None

    # Initialize context and setup source files before running the workflow
    _, extra_env, context_post_script = extract_context_from_run_args(args.run_args)

    # The command line post script comes first, the one from the context second.
    if post_script and context_post_script:
        merged_post_script = "{}\n{}".format(post_script, context_post_script)
    else:
        merged_post_script = post_script or context_post_script or None

    run_script = create_run_workflow_cmd(args.run_args, extra_env=extra_env, post_script=merged_post_script)
    return cmd + " " + run_script


if __name__ == '__main__':
    # Script entry point: parse the command line, generate the run script and
    # print the command that executes it on stdout.
    arguments = sys.argv[1:]
    oparser = get_parser(sys.argv[0])
    # argcomplete.autocomplete(oparser)

    logging.debug("arguments: ")
    logging.debug(arguments)
    args = oparser.parse_args(arguments)

    try:
        start_time = time.time()
        new_command = process_args(args)
        print(new_command)
        elapsed = time.time() - start_time
        logging.info("Completed processing args in %-0.4f sec." % elapsed)
    except Exception as error:
        # Any failure is logged with its traceback and reported through a
        # non-zero exit status.
        logging.error("Strange error: {0}".format(error))
        logging.error(traceback.format_exc())
        sys.exit(-1)
    else:
        sys.exit(0)
