#!/usr/bin/env python3
import os
import sys
import subprocess

import ats
from ats.util.generic_utils import execute, clean_old_sandboxes, clean_old_ats_log_dirs, \
                      set_machine_type_based_on_sys_type, get_interactive_partition

# The ats executable is expected to live in the bin directory of the running
# Python installation; abort immediately when it is missing.
myats = os.path.join(sys.exec_prefix, 'bin', 'ats')
if not os.path.exists(myats):
    sys.exit(f"Bummer! {myats} does not exist")
print(f"Most excellent! Found {myats}")

# -------------------------------------------------------------------------------------------------
# Read the platform and scheduler hints from the environment.
#
# SYS_TYPE      - platform identifier; the rest of this script only performs substring tests on
#                 it ('blueos', 'chaos', 'toss').  Default to "" when it is unset so those tests
#                 evaluate to False instead of raising TypeError on None.
# SLURM_JOB_ID  - set when the nodes are already allocated under slurm (e.g. pre-allocated on
#                 chaos); None means an salloc is needed later.
# LSB_BATCH_JID - set when already running inside an LSF job; None means a bsub is needed later.
# -------------------------------------------------------------------------------------------------
sys_type = os.getenv("SYS_TYPE", "")
slurm_job_id = os.getenv("SLURM_JOB_ID")
lsb_batch_jid = os.getenv("LSB_BATCH_JID")


# Scan the command line once and:
#   * rewrite every key=value argument in place as key="value",
#     (NOTE(review): presumably so the value survives a later shell parse -- confirm)
#   * record a test file (an argument ending in '.ats') and verify that it exists; when none is
#     given, test_ats_file stays "" and a default is chosen later,
#   * record whether an 'exclusive' option was given.
#
# The '.ats' suffix is tested before the 'exclusive' substring so that a test file whose name
# happens to contain "exclusive" is still treated as the test file rather than as the option.
test_ats_file = ""
exclusive_found = False
for index, arg in enumerate(sys.argv):
    print(arg)
    if index == 0:
        # argv[0] is this script's own path, not a user-supplied option.
        continue
    if '=' in arg:
        key, val = arg.split('=', 1)
        sys.argv[index] = f'{key}="{val}"'
    elif arg.endswith('.ats'):
        test_ats_file = arg
        if not os.path.exists(test_ats_file):
            sys.exit("Bummer! Did not find test file %s" % (test_ats_file))
    elif 'exclusive' in arg:
        exclusive_found = True

# -------------------------------------------------------------------------------------------------
# Set the default numNodes (2 on blueos, 3 everywhere else) and allow the user to override it with
# the --numNodes=99 option.
#
# Only the first argument containing 'numNodes' is honored.  Once its value has been saved for the
# bsub/salloc command, the argument is removed from sys.argv on chaos, toss and blueos systems,
# as it will not be needed further.  A missing or non-integer value is reported and the script
# exits with status -1.
# -------------------------------------------------------------------------------------------------
numNodes = 2 if 'blueos' in sys_type else 3

for index, arg in enumerate(sys.argv):
    # argv[0] is the script path, never an option.
    if index == 0 or 'numNodes' not in arg:
        continue

    # partition() never raises: for an argument without '=', sep and val are both empty.
    key, sep, val = arg.partition('=')
    if val.startswith('"') and val.endswith('"'):   # strip off possible quotes from integer
        val = val[1:-1]

    # isdecimal() rather than isdigit(): isdigit() also accepts characters such as superscript
    # digits, which int() rejects with a ValueError.
    if not sep or not val.isdecimal():
        print(f"ATS ERROR '{arg}' is invalid")
        sys.exit(-1)
    numNodes = int(val)

    if any(name in sys_type for name in ('chaos', 'toss', 'blueos')):
        # Safe to delete during iteration because the loop ends right after.
        del sys.argv[index]
    break

# -------------------------------------------------------------------------------------------------
# Call the ASC routines (imported from ats.util.generic_utils) to set the machine type based on
# the sys type, then get the name of the interactive partition.  The partition name is later
# passed to bsub (-q) or salloc (-p) when an allocation has to be requested.
# NOTE(review): presumably the partition lookup relies on the machine type set by the first call,
# so keep the two calls in this order -- confirm against generic_utils.
# -------------------------------------------------------------------------------------------------
set_machine_type_based_on_sys_type()
interactive_partition = get_interactive_partition()

# -------------------------------------------------------------------------------------------------
#  Startup
#
#  Build the ats command line, make sure a test file exists (creating it with the user's
#  create_test_ats.py script if necessary), clean up leftovers from earlier runs, wrap the command
#  in a bsub or salloc allocation when one is needed, run it, and exit with its return code.
# -------------------------------------------------------------------------------------------------
if __name__ == "__main__":

    # Base invocation: always run ats with --debug and forward every user argument.
    cmd = " ".join([myats, "--debug", *sys.argv[1:]])

    create_test_ats_py = os.path.join(os.getcwd(), "create_test_ats.py")

    # If user did not specify a file, then we will default to test.ats, or attempt to create
    # it if it does not exist using a script the user provides in this directory
    if test_ats_file == "":
        test_ats_file = os.path.join(os.getcwd(), "test.ats")
        cmd = f"{cmd} {test_ats_file}"

    if os.path.exists(test_ats_file):
        print(f"Found {test_ats_file}.")
    elif os.path.exists(create_test_ats_py):
        print(f"Found {create_test_ats_py}.")
        return_code = execute(create_test_ats_py, None, True)
        if return_code != 0:
            sys.exit(1)
    else:
        sys.exit("Bummer! Did not find %s or %s" % (test_ats_file, create_test_ats_py))

    clean_old_sandboxes()
    clean_old_ats_log_dirs()

    if 'blueos' in sys_type:
        # Request an LSF allocation only when not already inside an LSF job.
        if lsb_batch_jid is None:
            cmd = (f'bsub --private-launch -nnodes {numNodes} -Is -XF -W 120 '
                   f'-core_isolation 2 -q {interactive_partition} {cmd}')
    elif exclusive_found:
        print("Running script on login node -- each test will have a separate srun allocation")
        # No allocation is made here, so hand the node count to ats itself.
        cmd = f"{cmd} --numNodes={numNodes}"
    elif slurm_job_id is None:
        # Request a slurm allocation only when not already inside one.
        cmd = f'salloc -N {numNodes} -p {interactive_partition} --exclusive {cmd}'
    if numNodes == 1:
        # cmd ends with the ats invocation, so this flag is received by ats.
        cmd += ' --bypassSerialMachineCheck'

    print(f"Executing: {cmd}")
    # NOTE(review): the command is split on whitespace and run without a shell, so an argument
    # containing spaces is broken apart and any quotes embedded in arguments are passed through
    # literally -- confirm this is what ats / bsub / salloc expect.
    completed = subprocess.run(cmd.split(), text=True)

    # Propagate the result so callers (shell scripts, CI) can detect a failed run; previously
    # the return code was discarded and this script always exited with status 0.
    sys.exit(completed.returncode)
