#!/projects/jp/engineering/qbatch/env/bin/python
from __future__ import print_function
import argparse
import math
import os
import os.path
import re
import subprocess
import stat
import sys
import fnmatch
import errno

# Defaults for the command line options defined in the main section below.
# Each one can be overridden through the QBATCH_* environment variable named
# in the lookup. Values that come from the environment are strings; the
# built-in fallbacks for PPJ and NODES are ints.

# default for --ppj (processors per job)
PPJ = os.environ.get("QBATCH_PPJ", 1)
# default for -c/--chunksize; falls back to PPJ when not set
CHUNKSIZE = os.environ.get("QBATCH_CHUNKSIZE", PPJ)
# default for -j/--cores; falls back to PPJ when not set
CORES = os.environ.get("QBATCH_CORES", PPJ)
# default for -b (one of 'pbs', 'sge', 'local')
SYSTEM = os.environ.get("QBATCH_SYSTEM", "pbs")
# default for --nodes
NODES = os.environ.get("QBATCH_NODES", 1)
# default for --pe
PE = os.environ.get("QBATCH_PE", "smp")
# default for --memvars (comma-separated list)
MEMVARS = os.environ.get("QBATCH_MEMVARS", "mem")
# default for --mem; "0" means no memory request is emitted
MEM = os.environ.get("QBATCH_MEM", "0")
# folder the generated job scripts are written to
SCRIPT_FOLDER = os.environ.get("QBATCH_SCRIPT_FOLDER", ".qbatch/")

# environment vars to ignore when copying the environment to the job script.
# Entries are matched against variable names with fnmatch, so shell-style
# wildcards are allowed.
IGNORE_ENV_VARS = ['PWD', 'SGE_TASK_ID', 'PBS_ARRAYID', 'ARRAY_IND',
                   'BASH_FUNC_*']

# Job script header templates, one per supported system.
#
# The {placeholders} are filled in the main section with
# TEMPLATE.format(**vars()), so every placeholder name must match the name of
# a variable defined there. The surrounding newlines of each literal are
# removed by .strip(), which makes the shebang the first line of the script.

# Header for PBS job scripts. ARRAY_IND is set from PBS_ARRAYID so that the
# array job body can refer to the array index by a system-independent name.
PBS_HEADER_TEMPLATE = """
{shebang}
#PBS -l nodes={nodes}{highmem}:ppn={ppj}
#PBS -j oe
#PBS -o {logdir}
#PBS -d {workdir}
#PBS -N {job_name}
#PBS {o_memopts}
#PBS {o_array}
#PBS {o_walltime}
#PBS {o_dependencies}
#PBS {o_options}
{env}
{header_commands}
ARRAY_IND=$PBS_ARRAYID
""".strip()

# Header for SGE job scripts. Here {ppj} is not a number: the main section
# replaces it with either a complete "-pe <pe> <slots>" option or an empty
# string. ARRAY_IND is set from SGE_TASK_ID.
SGE_HEADER_TEMPLATE = """
{shebang}
#$ {ppj}
#$ -j y
#$ -o {logdir}
#$ -wd {workdir}
#$ -N {job_name}
#$ {o_memopts}
#$ {o_array}
#$ {o_walltime}
#$ {o_dependencies}
#$ {o_options}
{env}
{header_commands}
ARRAY_IND=$SGE_TASK_ID
""".strip()

# Header for scripts that are run locally: no scheduler directives, only the
# copied environment, the user supplied header lines and a cd into the
# working directory.
LOCAL_TEMPLATE = """
{shebang}
{env}
{header_commands}
cd {workdir}
""".strip()


def run_command(command, logfile=None):
    """Run a command and echo its output line by line as it is produced.

    stdout and stderr of the child process are merged. Every line read is
    stripped of surrounding whitespace and printed; if ``logfile`` is given
    the same stripped lines are also written to that file (which is
    truncated first), one per line.

    Returns the exit status of the command.
    """
    # Based on
    # http://blog.endpoint.com/2015/01/getting-realtime-output-using-python.html # noqa
    process = subprocess.Popen(
        command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
    filehandle = None
    try:
        if logfile:
            filehandle = open(logfile, 'wb')
        # readline() returns b'' only at end of file, i.e. once the child has
        # closed its end of the pipe
        for output in iter(process.stdout.readline, b''):
            line = output.strip()
            # bytes that are not valid UTF-8 must not abort the run; they are
            # replaced on screen while the log file keeps the raw bytes
            print(line.decode('UTF-8', 'replace'))
            if filehandle:
                filehandle.write(line + b'\n')
    finally:
        # release the pipe and the log file even if reading/writing failed
        process.stdout.close()
        if filehandle:
            filehandle.close()
    return process.wait()


def mkdirp(*p):
    """Like mkdir -p

    Joins the given path components, creates the resulting directory along
    with any missing parents, and returns the joined path. A directory that
    already exists is not an error.

    Raises OSError if the directory cannot be created, including when the
    path already exists but is not a directory.
    """
    path = os.path.join(*p)

    try:
        os.makedirs(path)
    except OSError as exc:
        # EEXIST is also reported when e.g. a regular file is in the way;
        # only an existing directory may be ignored
        if exc.errno != errno.EEXIST or not os.path.isdir(path):
            raise
    return path


def valid_file(string):
    """argparse type: accept '-' or the path of an existing regular file.

    Returns the argument unchanged; raises argparse.ArgumentTypeError for
    anything else.
    """
    if string == '-' or os.path.isfile(string):
        return string

    message = "The file {0} does not exist".format(string)
    raise argparse.ArgumentTypeError(message)


def positive_int(string):
    """argparse type: convert the argument to an integer of at least 1.

    Returns the converted value; raises argparse.ArgumentTypeError if the
    argument is not an integer or is smaller than 1.
    """
    try:
        number = int(string)
    except ValueError:
        number = None

    if number is None or number < 1:
        raise argparse.ArgumentTypeError("Must be a positive integer")
    return number


def int_or_percent(string):
    """Checks argument is an integer or integer percentage

    Accepted forms are an optionally signed integer (e.g. "4", "-1", "+2")
    or an unsigned integer followed by a percent sign (e.g. "100%").

    Returns the argument unchanged (as a string); raises
    argparse.ArgumentTypeError if it has neither form.
    """
    # raw string: "\d" in a plain string literal is an invalid escape
    # sequence
    if not re.match(r"^(?:[-+]?\d+|\d+%)$", string):
        msg = "Must be an integer or positive integer percentage"
        raise argparse.ArgumentTypeError(msg)
    return string


def pbs_find_jobs(patterns):
    """Finds jobs with names matching a given list of patterns

    ``patterns`` is a glob pattern or a list of glob patterns (fnmatch
    syntax) that are compared against the job names reported by 'qstat -x'.

    Returns a list of job IDs, in the order reported by qstat. Each matching
    job is listed once, even if its name matches several patterns. An empty
    list is returned when no patterns are given or qstat prints nothing.

    Raises an Exception if there is an error running the 'qstat' command or
    parsing its output.
    """
    if not patterns:
        return []

    if isinstance(patterns, str):
        patterns = [patterns]

    import xml.etree.ElementTree as ET

    output = subprocess.check_output(['qstat', '-x'])
    if not output:
        print(
            "WARN: Dependencies specified but no running jobs found",
            file=sys.stderr)
        return []
    tree = ET.fromstring(output)

    matches = []
    for job in tree:
        jobid = job.find('Job_Id').text
        name = job.find('Job_Name').text

        # any() stops at the first hit so a job is never added twice
        if any(fnmatch.fnmatch(name, pattern) for pattern in patterns):
            matches.append(jobid)

    return matches


def which(program):
    """Locate an executable, similar to the shell's `which`.

    If ``program`` contains a directory component it is checked as given;
    otherwise every directory listed in the PATH environment variable is
    searched in order (os.defpath is used when PATH is not set).

    Returns the path of the first executable regular file found, or None if
    there is none.
    """
    # Check for existence of important programs
    # Adapted from
    # http://stackoverflow.com/questions/377017/test-if-executable-exists-in-python # noqa
    def is_exe(fpath):
        return os.path.isfile(fpath) and os.access(fpath, os.X_OK)

    if os.path.dirname(program):
        return program if is_exe(program) else None

    # PATH can be missing from the environment altogether; fall back to the
    # platform default search path instead of failing with a KeyError
    for path in os.environ.get("PATH", os.defpath).split(os.pathsep):
        # entries are sometimes wrapped in double quotes
        exe_file = os.path.join(path.strip('"'), program)
        if is_exe(exe_file):
            return exe_file

    return None

if __name__ == "__main__":
    # Script entry point. The steps are: parse the command line, read the
    # command list, build the job script header for the selected system,
    # write the job script(s) to SCRIPT_FOLDER, then submit them with qsub
    # (pbs/sge) or run them directly (local).
    #
    # NOTE: the header templates are filled with .format(**vars()), so the
    # names of the variables defined below (shebang, nodes, highmem, ppj,
    # logdir, workdir, job_name, env, header_commands and the o_* values)
    # must stay in sync with the template placeholders.
    parser = argparse.ArgumentParser(
        usage="%(prog)s [options] <command_file>",
        description="""Submits a list of commands to a queueing system.
        The list of commands can be broken up into 'chunks' when submitted, so
        that the commands in each chunk run in parallel (using GNU parallel).
        The job script(s) generated by %(prog)s are stored in the folder
        {0}""".format(SCRIPT_FOLDER),
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        "command_file", type=valid_file,
        help="""An input file containing a list of shell commands to be
        submitted. Use - to read the command list from STDIN""")
    parser.add_argument(
        "-w", "--walltime",
        help="""Maximum walltime for an array job element or individual job""")
    parser.add_argument(
        "-c", "--chunksize", default=CHUNKSIZE, type=positive_int,
        help="""Number of commands from the command list that are wrapped into
        each job""")
    parser.add_argument(
        "-j", "--cores", default=CORES, type=int_or_percent,
        help="""Number of commands each job runs in parallel. If the chunk size
        (-c) is smaller than -j then only chunk size commands will run in
        parallel. This option can also be expressed as a percentage (e.g.
        100%%) of the total available cores""")
    parser.add_argument(
        "--ppj", default=PPJ, type=positive_int,
        help="""Requested number of processors per job (aka ppn on PBS,
        slots on SGE). Cores can be over subscribed if -j is larger than --ppj
        (useful to make use of hyper-threading on some systems)""")
    parser.add_argument(
        "-N", "--jobname", action="store",
        help="""Set job name (defaults to name of command file, or STDIN)""")
    parser.add_argument(
        "--mem", default=MEM,
        help="""Memory required for each job (e.g. --mem 1G).  This value will
        be set on each variable specified in --memvars. To not set any memory
        requirement, set this to 0""")
    parser.add_argument(
        "-n", action="store_true",
        help="Dry run; nothing is submitted or run")
    parser.add_argument("-v", action="store_true", help="Verbose output")

    group = parser.add_argument_group('advanced options')
    group.add_argument(
        "--afterok", action="append",
        help="""Wait for successful completion of job(s) with name matching
        glob pattern before starting""")
    group.add_argument(
        "-d", "--workdir", default=os.getcwd(),
        help="Job working directory")
    group.add_argument(
        "--logdir", action="store", default="{workdir}/logs",
        help="""Directory to save store log files""")
    group.add_argument(
        "-o", "--options", action="append", default=[],
        help="""Custom options passed directly to the queuing system (e.g
        --options "-l vf=8G". This option can be given multiple times""")
    group.add_argument(
        "--header", action="append",
        help="""A line to insert verbatim at the start of the script, and will
        be run once per job. This option can be given multiple times""")
    group.add_argument(
        "--nodes", default=NODES, type=positive_int,
        help="(PBS-only) Nodes to request per job")
    group.add_argument(
        "--pe", default=PE,
        help="""(SGE-only) The parallel environment to use if more than one
        processor per job is requested""")
    group.add_argument(
        "--memvars", default=MEMVARS,
        help="""A comma-separated list of variables to set with the memory
        limit given by the --mem option (e.g. --memvars=h_vmem,vf)""")
    group.add_argument(
        "--highmem", action="store_true",
        help="(Scinet-only) Submit to high memory nodes")
    group.add_argument(
        "-i", action="store_true",
        help="Submit individual jobs instead of an array job")
    group.add_argument(
        "-b", default=SYSTEM, choices=['pbs', 'sge', 'local'],
        help="""The type of queueing system to use. 'pbs' and 'sge' both make
        calls to qsub to submit jobs. 'local' runs the entire command list
        (without chunking) locally.""")
    group.add_argument(
        "--no-env", action="store_true",
        help="""Do not copy the current environment into the job script(s)""")

    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(1)

    args = parser.parse_args()

    command_file = args.command_file
    walltime = args.walltime
    chunk_size = args.chunksize
    ncores = args.cores
    ppj = args.ppj
    job_name = args.jobname
    # a memory value of "0" (or an empty one) means: request no memory
    mem = args.mem if (args.mem and args.mem != '0') else None
    verbose = args.v
    dry_run = args.n
    afterok_pattern = args.afterok
    workdir = args.workdir
    # the default log directory is expressed relative to the working directory
    logdir = args.logdir.format(workdir=workdir)
    options = args.options
    header_commands = '\n'.join(args.header) if args.header else ''
    nodes = args.nodes
    pe = args.pe
    memvars = args.memvars.split(',')
    highmem = ':m32g' if args.highmem else ''
    use_array = not args.i
    system = args.b
    copy_env = not args.no_env

    # argparse does not check defaults against `choices`, so a bad value in
    # QBATCH_SYSTEM would otherwise only surface later as a NameError
    if system not in ('pbs', 'sge', 'local'):
        parser.error(
            "invalid queueing system '{0}' "
            "(choose from 'pbs', 'sge', 'local')".format(system))

    # Preflight check
    if system == 'sge' or system == 'pbs':
        if not which("qsub"):
            sys.exit("qsub command not found")

    if not which("parallel"):
        sys.exit("parallel command not found")

    mkdirp(logdir)

    # read in commands
    if command_file == '-':
        task_list = sys.stdin.readlines()
        job_name = job_name or 'STDIN'
    else:
        with open(command_file) as command_fh:
            task_list = command_fh.readlines()
        job_name = job_name or os.path.basename(command_file)

    # compute the number of jobs needed. This will be the number of elements in
    # the array job
    if not task_list:
        print("No jobs to submit, exiting", file=sys.stderr)
        sys.exit()

    if system == 'local':
        # everything runs in one script, whatever the chunk size
        use_array = False
        num_jobs = 1
        chunk_size = sys.maxsize
    elif len(task_list) <= chunk_size:
        use_array = False
        num_jobs = 1
        print("Number of commands less than chunk size, "
              "building single non-array job", file=sys.stderr)
    else:
        num_jobs = int(math.ceil(len(task_list) / float(chunk_size)))

    # copy the current environment
    env = ''
    if copy_env:
        def quote_for_export(value):
            # Escape the characters that keep a special meaning inside
            # double quotes in the shell. The backslash is handled first so
            # that the backslashes added for the others are not doubled.
            # ('$' must become '\$': '$$' would expand to the shell's PID.)
            for char in ('\\', '"', '$', '`'):
                value = value.replace(char, '\\' + char)
            return value

        env = '\n'.join(['export {0}="{1}"'.format(k, quote_for_export(v))
                         for k, v in os.environ.items()
                         if not any(fnmatch.fnmatch(k, pattern) for pattern
                                    in IGNORE_ENV_VARS)])
        env = "# -- start copied env\n{0}\n# -- end copied env".format(env)

    # make script header
    shebang = '#!/bin/bash'

    if system == 'pbs':
        # PBS dependencies are given by job ID, so the name patterns have to
        # be resolved against the jobs currently known to the scheduler
        try:
            matching_jobids = pbs_find_jobs(afterok_pattern)
        except Exception as e:
            sys.exit("Error matching afterok pattern {0}".format(str(e)))

        o_array = '-t 1-{0}'.format(num_jobs) if use_array else ''
        o_walltime = "-l walltime={0}".format(walltime) if walltime else ''
        if matching_jobids:
            depend_type = 'afterokarray:' if use_array else 'afterok:'
            o_dependencies = (
                '-W depend=' + depend_type + ':'.join(matching_jobids))
        else:
            o_dependencies = ''
        # each custom option goes on its own directive line
        o_options = '\n#PBS '.join(options)
        mem_string = ','.join(["{0}={1}".format(var, mem) for var in memvars])
        o_memopts = '-l {0}'.format(mem_string) if (mem and mem_string) else ''

        header = PBS_HEADER_TEMPLATE.format(**vars())

    elif system == 'sge':
        # in the SGE template {ppj} is the complete parallel environment
        # option, which is only needed for more than one slot
        ppj = '-pe {0} {1}'.format(pe, ppj) if ppj > 1 else ''
        o_array = '-t 1-{0}'.format(num_jobs) if use_array else ''
        o_walltime = "-l h_rt={0}".format(walltime) if walltime else ''
        # SGE accepts job name patterns directly
        if afterok_pattern:
            o_dependencies = (
                '-hold_jid \'' + '\',\''.join(afterok_pattern) + '\'')
        else:
            o_dependencies = ''
        o_options = '\n#$ '.join(options)
        mem_string = ','.join(["{0}={1}".format(var, mem) for var in memvars])
        o_memopts = '-l {0}'.format(mem_string) if (mem and mem_string) else ''
        header = SGE_HEADER_TEMPLATE.format(**vars())

    elif system == 'local':
        header = LOCAL_TEMPLATE.format(**vars())

    # emit job scripts
    job_scripts = []
    mkdirp(SCRIPT_FOLDER)
    if use_array:
        # One script holds the whole command list; each array element uses
        # sed to pick the CHUNK_SIZE lines belonging to its ARRAY_IND and
        # pipes them to parallel.
        script_lines = [
            header,
            'CHUNK_SIZE={0}'.format(chunk_size),
            'CORES={0}'.format(ncores),
            'sed -n "$(( (${ARRAY_IND} - 1) * ${CHUNK_SIZE} + 1 )),'
            '+$(( ${CHUNK_SIZE} - 1 ))p" << EOF | parallel -v -j${CORES}',
            ''.join(task_list),
            'EOF']

        scriptfile = os.path.join(SCRIPT_FOLDER, job_name + ".array")
        with open(scriptfile, 'w') as script_fh:
            script_fh.write('\n'.join(script_lines))
        job_scripts.append(scriptfile)
    else:
        # one script per chunk, each containing only its own commands
        for chunk in range(num_jobs):
            scriptfile = os.path.join(
                SCRIPT_FOLDER, "{0}.{1}".format(job_name, chunk))
            if len(task_list) == 1:
                # a single command is run directly, without parallel
                script_lines = [
                    header,
                    ''.join(task_list)]
            else:
                first = chunk * chunk_size
                script_lines = [
                    header,
                    'CORES={0}'.format(ncores),
                    "parallel -v -j${CORES} << EOF",
                    ''.join(task_list[first:first + chunk_size]),
                    'EOF']
            with open(scriptfile, "w") as script_fh:
                script_fh.write('\n'.join(script_lines))
            job_scripts.append(scriptfile)

    # execute the job script(s)
    for script in job_scripts:
        os.chmod(script, os.stat(script).st_mode | stat.S_IXUSR)
        if system == 'sge' or system == 'pbs':
            if verbose:
                print("Running: qsub {0}".format(script))
            if dry_run:
                continue
            return_code = subprocess.call(['qsub', script])
            if return_code:
                sys.exit("qsub call " +
                         "returned error code {0}".format(return_code))
        else:
            logfile = "{0}/{1}.log".format(logdir, job_name)
            if verbose:
                print("Launching jobscript. Output to {0}".format(logfile))
            if dry_run:
                continue
            run_command(script, logfile=logfile)
