#!python
# -*- coding: utf-8 -*-

"""This script allows you to submit Slurm batch scripts via `submit`.
After your jobs are completed, you can check for failed jobs and save
these to a rescue file via `rescue`.

"""
import json
import os
import time

import pyssub.sbatch
import pyssub.scmd
import pyssub.shist


# ---Script's subcommands------------------------------------------------------
def submit(inputfile, outputfile,
           rescue=None,
           nmax=1000,
           wait=120,
           partition=None):
    """Submit Slurm batch scripts.

    Successively submit the batch scripts listed in the given input file
    into the Slurm queue, requesting the given partition and checking
    that not more jobs than `nmax` are queuing. Wait for `wait` seconds
    before trying to submit more jobs. Save name and ID of the
    submitted jobs in the given output file. If a rescue file is given,
    which contains failed jobs from a previous submission, only these
    jobs are resubmitted.

    """
    # NOTE: The docstring above is also used as the description of the
    # `submit` subcommand in the command-line help; keep it plain prose.

    # Decode the JSON input file into a mapping of job names to scripts.
    with open(inputfile, "r") as handle:
        decoder = pyssub.sbatch.SBatchScriptDecoder()
        scripts = json.load(handle, object_hook=decoder)

    # With a rescue file, keep only the scripts whose job name it lists.
    if rescue is not None:
        failed = pyssub.shist.load(rescue)
        scripts = {
            name: script
            for name, script in scripts.items()
            if name in failed
            }

    # Submit everything first; the job IDs are written to disk afterwards.
    submitted = sbatch(scripts, partition, nmax, wait)
    pyssub.shist.save(outputfile, jobs=submitted)


def rescue(inputfile, outputfile):
    """Save a rescue file to disk.

    Load the submitted jobs listed in the given input file from disk.
    These should have completed, otherwise they will be ignored. Check
    which of these jobs have failed and save those to disk, using the
    path to the given output file.

    """
    # NOTE: The docstring above is also used as the description of the
    # `rescue` subcommand in the command-line help; keep it plain prose.
    submitted = pyssub.shist.load(inputfile)

    # Reduce the submitted jobs to the ones reported as failed.
    failed = pyssub.scmd.failed(submitted)

    pyssub.shist.save(outputfile, jobs=failed)


# ---Slurm batch script submission---------------------------------------------
def sbatch(scripts, partition=None, nmax=1000, wait=120):
    """Submit Slurm batch scripts in throttled rounds.

    In every round, the job count reported by `pyssub.scmd.numjobs`
    for the current user ID and the given partition is subtracted from
    `nmax`; at most that many of the remaining scripts are submitted.
    As long as scripts are left over after a round, the function
    sleeps for `wait` seconds before starting the next one.

    Parameters
    ----------
    scripts : dict(str, SBatchScript)
        Mapping of job names to Slurm batch scripts; it is copied and
        not modified
    partition : str, optional
        Partition for resource allocation
    nmax : int, optional
        Maximum allowed number of queuing jobs
    wait : int, optional
        Number of seconds to sleep between two rounds

    Returns
    -------
    dict(str, int)
        Mapping of job names to job IDs

    """
    # Work on a shallow copy, so the caller's mapping stays untouched.
    pending = dict(scripts)
    submitted = {}

    while pending:
        queued = pyssub.scmd.numjobs(str(os.getuid()), partition)
        free = nmax - queued
        batch = min(free, len(pending))

        # A non-positive batch size means nothing is submitted this round.
        for _ in range(batch):
            name, script = pending.popitem()
            submitted[name] = pyssub.scmd.submit(script, partition)

        # Only sleep if another round is actually needed.
        if pending:
            time.sleep(wait)

    return submitted


if __name__ == "__main__":
    import argparse

    # The module docstring and the docstrings of `submit` and `rescue`
    # are reused as help descriptions of the parser and its subcommands.
    parser = argparse.ArgumentParser(description=__doc__)

    # Options shared by all subcommands are collected in a parent parser.
    parent = argparse.ArgumentParser(add_help=False)

    # The options below whose value is mandatory do not use `nargs="?"`:
    # with it, a bare flag would silently be parsed as `None` and crash
    # later on instead of producing a usage error.
    parent.add_argument(
        "--in",
        type=str,
        help="path to input file",
        required=True,
        metavar="PATH",
        dest="inputfile")

    parent.add_argument(
        "--out",
        type=str,
        help="path to output file",
        required=True,
        metavar="PATH",
        dest="outputfile")

    subparsers = parser.add_subparsers()

    # ---Subcommand `submit`---
    subparser = subparsers.add_parser(
        "submit", description=submit.__doc__, parents=[parent])

    # A bare `--rescue` is parsed as `None`, which means "no rescue file".
    subparser.add_argument(
        "--rescue",
        nargs="?",
        type=str,
        help="path to rescue file",
        metavar="PATH",
        dest="rescue")

    subparser.add_argument(
        "--nmax",
        type=int,
        help="maximum allowed number of queuing jobs (%(default)s)",
        default=1000,
        metavar="NUM")

    subparser.add_argument(
        "--wait",
        type=int,
        help="number of seconds before next submission (%(default)s)",
        default=120,
        metavar="SEC")

    # A bare `--partition` is parsed as `None`, which equals the default.
    subparser.add_argument(
        "--partition",
        nargs="?",
        type=str,
        help="request this partition",
        default=None,
        metavar="NAME")

    subparser.set_defaults(function=submit)

    # ---Subcommand `rescue`---
    subparser = subparsers.add_parser(
        "rescue", description=rescue.__doc__, parents=[parent])

    subparser.set_defaults(function=rescue)

    kwargs = vars(parser.parse_args())

    # Subcommands are optional in Python 3; without one, no `function`
    # default is set. Report a usage error instead of raising `KeyError`.
    function = kwargs.pop("function", None)
    if function is None:
        parser.error("a subcommand is required: choose `submit` or `rescue`")

    # The remaining parsed options match the subcommand's keyword arguments.
    function(**kwargs)
