#!/usr/bin/env python

"""
MPI wrapper for redrock
"""

from __future__ import absolute_import, division, print_function
import sys, os, glob, time, subprocess, re
import argparse
import numpy as np
from astropy.io import fits
from redrock.external import desi
from redrock.utils import nersc_login_node

def weighted_partition(weights, n):
    '''
    Split `weights` into `n` groups whose summed weights are roughly equal

    Args:
        weights: array-like weights
        n: number of groups

    Returns (groups, groupweights):
        * groups: list of `n` lists of indices into `weights`
        * groupweights: array with the summed weight of each group

    Notes:
        Greedy assignment: the weights are visited from largest to smallest
        and each one goes to whichever group currently has the smallest
        total, so the indices within a group are ordered by decreasing
        weight.

        `redrock.utils.distribute_work` does a similar job but was written
        independently; the two have not yet been compared.
    '''
    totals = np.zeros(n, dtype=float)
    groups = [[] for _ in range(n)]
    weights = np.asarray(weights)

    #- largest weights first
    for iweight in np.argsort(-weights):
        lightest = np.argmin(totals)
        groups[lightest].append(iweight)
        totals[lightest] += weights[iweight]

    #- hand back a copy of the running totals
    return groups, np.array(totals)

def spectra2outfiles(specfiles, inprefix, outprefix, ext=None, outdir=None):
    '''
    Map each input spectra filepath to the matching output filepath

    Args:
        specfiles: list of spectra filepaths
        inprefix: input file prefix, e.g. 'spectra' or 'spPlate'
        outprefix: output file prefix, e.g. 'redrock'

    Options:
        ext: output extension, e.g. 'h5' (default same as input extension)
        outdir: output directory (default same as each input file directory)

    Returns:
        array of output filepaths, in the same order as `specfiles`

    Notes:
        The substitutions are plain string replacements on the file name:
        `inprefix` -> `outprefix`, and '.fits' -> '.' + ext when `ext` is
        given.

    Example::

        spectra2outfiles(['/a/b/blat-1.fits', '/c/d/blat-2.fits'],
            inprefix='blat', outprefix='foo', ext='h5', outdir='/tmp/')
        --> array(['/tmp/foo-1.h5', '/tmp/foo-2.h5'], dtype='<U13')
    '''
    def _outpath(specfile):
        #- swap prefix (and optionally extension) in the file name only,
        #- never in the directory part
        indir, name = os.path.split(specfile)
        name = name.replace(inprefix, outprefix)
        if ext is not None:
            name = name.replace('.fits', '.'+ext)
        return os.path.join(indir if outdir is None else outdir, name)

    return np.array([_outpath(specfile) for specfile in specfiles])

def find_specfiles(reduxdir, outdir=None, prefix='spectra', avoiddir=None):
    '''
    Returns list of spectra files under reduxdir that need to be processed

    Args:
        reduxdir: path to redux directory

    Options:
        outdir: path to output directory [default to same dir as inputs]
        prefix: filename prefix, e.g. 'spectra' or 'spPlate'
        avoiddir: directory whose subdirectories are *not* traversed

    Returns:
        array of spectra files to process

    Raises:
        IOError: if no files matching prefix*.fits are found at all

    Notes:
        Recursively walks directories under `reduxdir` (following symlinks)
        looking for files matching prefix*.fits.  A spectra file is
        considered done, and left out of the result, only if both the
        matching redrock*.fits and rrdetails*.h5 already exist in the same
        directory as the spectra file, or in outdir if specified.

        Files directly inside `avoiddir` are still considered; only the
        directories below it are skipped.
    '''
    #- Compare absolute paths so that a relative `reduxdir` still matches
    #- an absolute `avoiddir` (and vice versa)
    if avoiddir is not None:
        avoiddir = os.path.abspath(avoiddir)

    specfiles = list()
    for dirpath, dirnames, filenames in os.walk(reduxdir, followlinks=True, topdown=True):
        if avoiddir is not None and os.path.abspath(dirpath) == avoiddir:
            #- with topdown=True, emptying dirnames in place stops os.walk
            #- from descending below this directory
            del dirnames[:]

        for filename in filenames:
            if filename.startswith(prefix) and filename.endswith('.fits'):
                specfiles.append(os.path.join(dirpath, filename))

    if len(specfiles) == 0:
        raise IOError('no specfiles found')

    rrfiles = spectra2outfiles(specfiles, prefix, 'redrock', outdir=outdir)
    detailsfiles = spectra2outfiles(specfiles, prefix, 'rrdetails', outdir=outdir, ext='h5')

    #- still to do unless *both* outputs are already on disk
    todo = np.array([
        not (os.path.exists(rrfile) and os.path.exists(detailsfile))
        for rrfile, detailsfile in zip(rrfiles, detailsfiles)
    ], dtype=bool)

    return np.array(specfiles)[todo]

def group_specfiles(specfiles, maxnodes=256, comm=None):
    '''
    Distribute specfiles over nodes so that the expected runtimes balance

    Args:
        specfiles: list of spectra filepaths

    Options:
        maxnodes: upper limit on the number of groups; only used when
            `comm` is None
        comm: MPI communicator; if given, the files are read in parallel
            and exactly comm.size groups are made

    Returns (groups, ntargets, grouptimes):
      * groups: list of lists of indices to specfiles
      * ntargets: array with the number of targets per group
      * grouptimes: array of expected runtimes for each group

    Notes:
        The number of targets of a file is the number of unique TARGETID
        values in its FIBERMAP HDU.
    '''
    rank = 0 if comm is None else comm.rank
    size = 1 if comm is None else comm.size

    #- each rank counts the targets in its own share of the files
    myfiles = np.array_split(np.arange(len(specfiles)), size)[rank]
    ntargets = np.zeros(len(myfiles), dtype=int)
    for k, ifile in enumerate(myfiles):
        fibermap = fits.getdata(specfiles[ifile], 'FIBERMAP')
        ntargets[k] = len(np.unique(fibermap['TARGETID']))

    #- collect the per-file counts on rank 0, then share the full array
    if comm is not None:
        gathered = comm.gather(ntargets)
        if rank == 0:
            gathered = np.concatenate(gathered)
        ntargets = comm.bcast(gathered, root=0)

    #- runtime model: fixed overhead per file plus a cost per target
    runtimes = 30 + 0.4*ntargets

    #- with MPI use one group per rank; otherwise aim for 25 minutes per
    #- group, but don't exceed maxnodes number of nodes
    if comm is None:
        numnodes = min(maxnodes, int(np.ceil(np.sum(runtimes)/(25*60))))
    else:
        numnodes = comm.size

    groups, grouptimes = weighted_partition(runtimes, numnodes)
    targets_per_group = np.array([np.sum(ntargets[ii]) for ii in groups])
    return groups, targets_per_group, grouptimes

def backup_logs(logfile):
    '''
    Move logfile -> logfile.N and return the new name

    N is one more than the highest numbered backup (logfile.0, logfile.1,
    ...) that already exists next to `logfile`, or 0 if there is none, so
    newer backups always get higher numbers and an existing backup is never
    overwritten.  Files with a non-numeric suffix such as logfile.abc are
    ignored.

    Args:
        logfile: path to an existing log file

    Returns:
        the path that logfile was renamed to

    Raises:
        OSError: if logfile or its directory does not exist
    '''
    dirname, basename = os.path.split(logfile)

    #- match "<basename>.<digits>" exactly; list the directory instead of
    #- globbing so that glob metacharacters in the path are harmless
    numbered = re.compile(re.escape(basename) + r'\.([0-9]+)\Z')
    numbers = list()
    for name in os.listdir(dirname or '.'):
        match = numbered.match(name)
        if match is not None:
            numbers.append(int(match.group(1)))

    nextnum = max(numbers) + 1 if numbers else 0
    newlog = '{}.{}'.format(logfile, nextnum)
    os.rename(logfile, newlog)
    return newlog

def plan(args, comm=None):
    '''
    Find the spectra files that still need processing and split them into
    one group per node; with args.plan also print a batch script to stdout

    Args:
        args: parsed command line options.  Reads reduxdir, outdir, prefix,
            datatype, maxnodes, plan and mp, plus nminima and archetypes if
            present.  When args.plan is set and args.mp is None, rank 0 sets
            args.mp to half the number of processes per node.

    Options:
        comm: MPI communicator; the file search is done on rank 0 and
            broadcast to the other ranks

    Returns (specfiles, groups, grouptimes):
        * specfiles: array of spectra files still to be processed
        * groups: list of arrays/lists of indices into specfiles
        * grouptimes: array of expected runtimes in seconds for each group

        All three are empty lists if there is nothing left to process.

    Raises:
        ValueError: if args.datatype is neither 'desi' nor 'boss'
        IOError: from find_specfiles, if no spectra files exist at all
    '''
    t0 = time.time()
    rank = 0 if comm is None else comm.rank

    if rank == 0:
        if args.datatype == 'boss':
            avoiddir = os.path.abspath(os.path.join(args.reduxdir, 'spectra'))
        else:
            avoiddir = None

        specfiles = find_specfiles(args.reduxdir, args.outdir,
                prefix=args.prefix, avoiddir=avoiddir)
    else:
        specfiles = None

    if comm is not None:
        specfiles = comm.bcast(specfiles, root=0)

    if len(specfiles) == 0:
        if rank == 0:
            print('All specfiles processed')
        return list(), list(), list()

    if args.datatype == 'desi':
        groups, ntargets, grouptimes = group_specfiles(specfiles, args.maxnodes, comm=comm)
    elif args.datatype == 'boss':
        #- BOSS files all have the same number of spectra, so no load balancing
        ngroups = args.maxnodes
        if args.plan:
            #- when only planning, never ask for more nodes than files;
            #- when running, the number of groups has to match the ranks
            ngroups = min(ngroups, len(specfiles))

        groups = np.array_split(np.arange(len(specfiles)), ngroups)
        ntargets = [1000*len(x) for x in groups]
        #- float array (not a list) so that it can be rescaled below
        grouptimes = np.array([30 + 250*len(x) for x in groups], dtype=float)
    else:
        raise ValueError('Unknown --datatype {}'.format(args.datatype))

    if args.plan and rank == 0:
        plantime = time.time() - t0
        numnodes = len(groups)

        nersc_host = os.getenv('NERSC_HOST')
        if nersc_host == 'cori':
            maxproc = 64
        elif nersc_host == 'edison':
            maxproc = 48
        else:
            maxproc = 8

        if args.mp is None:
            args.mp = maxproc // 2

        #- scale longer if purposefully using fewer cores (e.g. for memory)
        if args.mp < maxproc // 2:
            scale = (maxproc//2) / args.mp
            grouptimes = np.asarray(grouptimes, dtype=float) * scale

        #- request 15% more than the expected runtime
        jobtime = int(1.15 * (plantime + np.max(grouptimes)))
        jobhours = jobtime // 3600
        jobminutes = (jobtime - jobhours*3600) // 60
        jobseconds = jobtime - jobhours*3600 - jobminutes*60

        #- pick the queue from the walltime that is actually requested, i.e.
        #- after the --mp scaling and the margin, so -q and -t are consistent
        if jobtime <= (30*60):
            queue = 'debug'
        else:
            queue = 'regular'

        print('#!/bin/bash')
        print('#SBATCH -N {}'.format(numnodes))
        print('#SBATCH -q {}'.format(queue))
        print('#SBATCH -J redrock')
        if nersc_host == 'cori':
            print('#SBATCH -C haswell')
        print('#SBATCH -t {:02d}:{:02d}:{:02d}'.format(jobhours, jobminutes, jobseconds))
        print()
        print('# {} pixels with {} targets'.format(len(specfiles), np.sum(ntargets)))
        ### print('# plan time {:.1f} minutes'.format(plantime / 60))
        print('# Using {} nodes in {} queue'.format(numnodes, queue))
        print('# expected rank runtimes ({:.1f}, {:.1f}, {:.1f}) min/mid/max minutes'.format(
            np.min(grouptimes)/60, np.median(grouptimes)/60, np.max(grouptimes)/60
        ))
        ibiggest = np.argmax(grouptimes)
        print('# Largest node has {} specfile(s) with {} total targets'.format(
            len(groups[ibiggest]), ntargets[ibiggest]))

        print()
        print('export OMP_NUM_THREADS=1')
        print('unset OMP_PLACES')
        print('unset OMP_PROC_BIND')
        print('export MPICH_GNI_FORK_MODE=FULLCOPY')
        print()
        print('nodes=$SLURM_JOB_NUM_NODES')

        #- absolute paths so the script can be submitted from any directory
        rrcmd = '{} --mp {} --reduxdir {}'.format(
            os.path.abspath(__file__), args.mp, os.path.abspath(args.reduxdir))
        if args.outdir is not None:
            rrcmd += ' --outdir {}'.format(os.path.abspath(args.outdir))

        #- carry over the options that change what the job does, so that the
        #- job runs what was planned
        for option in ('datatype', 'prefix', 'maxnodes', 'nminima', 'archetypes'):
            value = getattr(args, option, None)
            if value is not None:
                rrcmd += ' --{} {}'.format(option, value)

        print('srun -N $nodes -n $nodes -c {} {}'.format(maxproc, rrcmd))

    return specfiles, groups, grouptimes

def _pixel_label(specfile):
    '''
    Label used for `specfile` in log messages: the name of its parent
    directory, as an int when it is numeric and as a string otherwise
    '''
    label = os.path.basename(os.path.dirname(specfile))
    try:
        return int(label)
    except ValueError:
        return label

def run_redrock(args, comm=None):
    '''
    Run rrdesi or rrboss on every spectra file that still needs processing

    The files are grouped by plan(); each rank works through its own group,
    running one subprocess per file with stdout and stderr sent to a log
    file next to the output file.  A failing command is tried twice.
    Failures are reported on stdout but do not stop the other files.

    Args:
        args: parsed command line options.  Reads datatype, prefix, outdir,
            nminima, mp, archetypes and dryrun, plus whatever plan() reads.
            args.maxnodes is lowered to the number of ranks if it is larger.

    Options:
        comm: MPI communicator; without it everything runs on one rank

    Raises:
        RuntimeError: if the groups from plan() do not match the number of
            ranks or do not cover all files
        ValueError: if the log file name would equal the output file name
    '''
    import traceback

    if comm is None:
        rank, size = 0, 1
    else:
        rank, size = comm.rank, comm.size

    args.maxnodes = min(args.maxnodes, size)

    t0 = time.time()
    if rank == 0:
        print('Starting at {}'.format(time.asctime()))

    specfiles, groups, _ = plan(args, comm=comm)

    if rank == 0:
        print('Initial setup took {:.1f} sec'.format(time.time() - t0))

    sys.stdout.flush()
    if comm is not None:
        groups = comm.bcast(groups, root=0)
        specfiles = comm.bcast(specfiles, root=0)

    #- plan() returns empty lists when every file has already been processed;
    #- specfiles is identical on all ranks, so they all return together
    if len(specfiles) == 0:
        return

    #- explicit checks instead of assert, which is skipped under python -O
    if len(groups) != size:
        raise RuntimeError('{} groups planned for {} ranks'.format(len(groups), size))
    if len(np.concatenate(groups)) != len(specfiles):
        raise RuntimeError('groups do not cover all {} specfiles'.format(len(specfiles)))

    pixels = [_pixel_label(x) for x in specfiles]
    rrfiles = spectra2outfiles(specfiles, args.prefix, 'redrock', outdir=args.outdir)
    detailsfiles = spectra2outfiles(specfiles, args.prefix, 'rrdetails', outdir=args.outdir, ext='h5')

    for i in groups[rank]:
        print('---- rank {} pix {} {}'.format(rank, pixels[i], time.asctime()))
        sys.stdout.flush()

        #- build the command as an argument list so that paths containing
        #- whitespace reach the program intact
        if args.datatype == 'desi':
            cmd = ['rrdesi', str(specfiles[i])]
        elif args.datatype == 'boss':
            cmd = ['rrboss', '--spplate', str(specfiles[i])]

        cmd += ['--nminima', str(args.nminima)]
        cmd += ['--outfile', str(rrfiles[i]), '--details', str(detailsfiles[i])]

        logfile = rrfiles[i].replace('.fits', '.log')
        if logfile == rrfiles[i]:
            #- never let the log overwrite the output file
            raise ValueError('cannot derive log file name from {}'.format(rrfiles[i]))

        if args.mp is not None:
            cmd += ['--mp', str(args.mp)]

        if args.archetypes is not None:
            cmd += ['--archetypes', str(args.archetypes)]

        print('Rank {} RUNNING {}'.format(rank, ' '.join(cmd)))
        print('LOGGING to {}'.format(logfile))
        sys.stdout.flush()

        if args.dryrun:
            continue

        try:
            maxtries = 2
            for retry in range(maxtries):
                t1 = time.time()
                if os.path.exists(logfile):
                    backup_logs(logfile)
                #- memory leak?  Make a system call instead of calling
                #- desi.rrdesi in this process
                with open(logfile, 'w') as log:
                    returncode = subprocess.call(cmd, stdout=log, stderr=log)
                dt1 = time.time() - t1
                if returncode == 0:
                    print('FINISHED pix {} rank {} try {} in {:.1f} sec'.format(pixels[i], rank, retry, dt1))
                    for outfile in [rrfiles[i], detailsfiles[i]]:
                        if not os.path.exists(outfile):
                            print('ERROR pix {} missing {}'.format(pixels[i], outfile))
                    break  #- don't need to retry
                else:
                    print('FAILED pix {} rank {} try {} in {:.1f} sec error code {}'.format(pixels[i], rank, retry, dt1, returncode))
                    if retry == maxtries-1:
                        print('FATAL pix {} failed {} times; giving up'.format(pixels[i], maxtries))
                    else:
                        #- random pause before the second try
                        time.sleep(np.random.uniform(1,5))

        except Exception:
            #- best effort: report and carry on with the next file
            print('FAILED: pix {} rank {} raised an exception'.format(pixels[i], rank))
            traceback.print_exc()

    print('---- rank {} is done'.format(rank))
    sys.stdout.flush()

    if comm is not None:
        comm.barrier()

    if rank == 0 and not args.dryrun:
        for outfile in rrfiles:
            if not os.path.exists(outfile):
                print('ERROR missing {}'.format(outfile))

        print('all done at {}'.format(time.asctime()))

#-------------------------------------------------------------------------
def main():
    '''
    Command line entry point: parse the options, set up MPI if wanted and
    available, then either print a plan (--plan) or run redrock
    '''
    parser = argparse.ArgumentParser(usage = "wrap-redrock [options]")
    add = parser.add_argument
    add("--reduxdir", type=str, required=True, help="input redux base directory")
    add("--outdir", type=str, help="output directory")
    add("--archetypes", type=str, help="archetypes directory")
    add("--mp", type=int, help="number of multiprocessing processes per MPI rank")
    add("--nompi", action="store_true", help="Do not use MPI parallelism")
    add("--dryrun", action="store_true", help="Generate but don't run commands")
    add("--maxnodes", type=int, default=256, help="maximum number of nodes to use")
    add("--nminima", type=int, default=3, help="number of zchi2 minima to keep per template type")
    add("--plan", action="store_true", help="plan how many nodes to use and pixel distribution")
    add("--datatype", type=str, default='desi',
        help="desi (default) or boss", choices=['desi', 'boss'])
    add("--prefix", type=str, help="spectra file name prefix")
    args = parser.parse_args()

    #- default file name prefix depends on the type of data
    if args.prefix is None:
        default_prefix = {'desi': 'spectra', 'boss': 'spPlate'}
        if args.datatype not in default_prefix:
            print('ERROR: unknown input prefix for datatype {}; use --prefix'.format(args.datatype))
            sys.exit(1)
        args.prefix = default_prefix[args.datatype]

    #- no MPI if not wanted, on a NERSC login node, or if mpi4py is missing
    if args.nompi or nersc_login_node():
        comm = None
    else:
        try:
            from mpi4py import MPI
        except ImportError:
            comm = None
        else:
            comm = MPI.COMM_WORLD

    if args.plan:
        plan(args, comm=comm)
        return

    #- Create output directory if needed, on a single rank only
    is_root = (comm is None) or (comm.rank == 0)
    if is_root and (args.outdir is not None):
        os.makedirs(args.outdir, exist_ok=True)

    #- All ranks wait for directory to be created
    if comm is not None:
        comm.barrier()

    run_redrock(args, comm=comm)

#- Run main() only when executed as a script, not when imported
if __name__ == '__main__':
    main()
