#!python
"""
MPI wrapper for fastphot and fastspec.

"""
import os, time
import numpy as np
from astropy.table import Table

from fastspecfit.logger import log


def build_cmdargs(args, redrockfile, outfile, sample=None, fastphot=False,
                  input_redshifts=False):
    """Build the command name, argument string, and log filename for one file.

    Parameters
    ----------
    args : namespace-like
        Parsed options. The attributes read here are `makeqa`, `mp`,
        `verbose`, `ntargets`, `firsttarget`, `targetids` and, when
        `makeqa` is false, `ignore_quasarnet`, `ignore_photometry`,
        `no_smooth_continuum`, `templates`, `templateversion`, `fphotodir`,
        `fphotofile`, `nmonte`, and `seed`.
    redrockfile : str
        Input file passed as the positional argument. With `args.makeqa` the
        roles are swapped: `redrockfile` is passed to `-o` (and is used as the
        directory of the log file) and `outfile` becomes the positional
        argument.
    outfile : str
        Output filename (or, with `args.makeqa`, the positional input).
    sample : table-like, optional
        If given, must be indexable by 'SURVEY', 'PROGRAM', 'HEALPIX', and
        'TARGETID' (plus 'Z' when `input_redshifts` is set). The rows matching
        the survey/program/healpix parsed from the basename of `redrockfile`
        select the targets to fit; `args.targetids` is then ignored.
    fastphot : bool
        Select 'fastphot' rather than 'fastspec' (ignored with `args.makeqa`).
    input_redshifts : bool
        With `sample`, also pass the matching 'Z' values via
        `--input-redshifts`.

    Returns
    -------
    cmd : str
        One of 'fastqa', 'fastphot', or 'fastspec'.
    cmdargs : str
        Space-separated argument string.
    logfile : str
        Log filename derived from `outfile` ('.gz' stripped, '.fits' replaced
        by '.log').

    """
    # With --makeqa the desired output directories are in the 'redrockfiles'.
    if args.makeqa:
        cmd = 'fastqa'
        cmdargs = f'{outfile} -o={redrockfile} --mp={args.mp}'
    else:
        cmd = 'fastphot' if fastphot else 'fastspec'
        cmdargs = f'{redrockfile} -o={outfile} --mp={args.mp}'

        if args.ignore_quasarnet:
            cmdargs += ' --ignore-quasarnet'
        # Keyed on its own flag; it must not follow --ignore-quasarnet.
        if args.ignore_photometry:
            cmdargs += ' --ignore-photometry'
        if args.no_smooth_continuum:
            cmdargs += ' --no-smooth-continuum'
        if args.templates:
            cmdargs += f' --templates={args.templates}'
        if args.templateversion:
            cmdargs += f' --templateversion={args.templateversion}'
        if args.fphotodir:
            cmdargs += f' --fphotodir={args.fphotodir}'
        if args.fphotofile:
            cmdargs += f' --fphotofile={args.fphotofile}'
        # Compare against None so that an explicit 0 is still forwarded
        # rather than silently replaced by the downstream default.
        if args.nmonte is not None:
            cmdargs += f' --nmonte={args.nmonte}'
        if args.seed is not None:
            cmdargs += f' --seed={args.seed}'
        # NOTE(review): main() parses --input-seeds but nothing forwards it
        # here -- confirm whether that is intentional.

    if args.verbose:
        cmdargs += ' --verbose'
    if args.ntargets is not None:
        cmdargs += f' --ntargets={args.ntargets}'
    if args.firsttarget is not None:
        cmdargs += f' --firsttarget={args.firsttarget}'

    if sample is not None:
        # assume healpix coadds; find the targetids to process -- fragile:
        # the basename must have exactly four '-'-separated fields.
        _, survey, program, healpix = os.path.basename(redrockfile).split('-')
        healpix = int(healpix.split('.')[0])
        I = ((sample['SURVEY'] == survey) &
             (sample['PROGRAM'] == program) &
             (sample['HEALPIX'] == healpix))
        matched = sample[I]
        targetids = ','.join(matched['TARGETID'].astype(str))
        cmdargs += f' --targetids={targetids}'
        if input_redshifts:
            inputz = ','.join(matched['Z'].astype(str))
            cmdargs += f' --input-redshifts={inputz}'
    elif args.targetids is not None:
        cmdargs += f' --targetids={args.targetids}'

    if args.makeqa:
        # In QA mode 'redrockfile' is the output directory, so the log goes there.
        logname = os.path.basename(outfile).replace('.gz', '').replace('.fits', '.log')
        logfile = os.path.join(redrockfile, logname)
    else:
        logfile = outfile.replace('.gz', '').replace('.fits', '.log')

    return cmd, cmdargs, logfile


def run_fastspecfit(args, comm=None, fastphot=False, specprod_dir=None, makeqa=False,
                    sample=None, input_redshifts=False, outdir_data='.', templates=None,
                    templateversion=None, fphotodir=None, fphotofile=None):
    """Main wrapper to run fastspec, fastphot, or fastqa.

    Top-level MPI parallelization is over (e.g., healpix) files, but there is
    another level of parallelization which makes use of subcommunicators.

    For example, calling `mpi-fastspecfit` with 8 MPI tasks, --purempi, and
    --mp=4 will result in two (8/4) healpix files being processed
    simultaneously (specifically, by ranks 0 and 4) and then a further level
    of parallelization over the objects in each of those files, specifically
    by subranks (0, 1, 2, 3) and (4, 5, 6, 7), respectively. Without
    --purempi every rank is its own sub-communicator of size one.

    Parameters
    ----------
    args : namespace-like
        Parsed command-line options; also passed to `build_cmdargs`.
    comm : MPI communicator or None
        If falsy, everything runs serially on a single "rank 0".
    fastphot : bool
        Run fastphot instead of fastspec (forwarded to `plan` and
        `build_cmdargs`).
    specprod_dir : str or None
        Forwarded to `plan`.
    makeqa : bool
        Accepted but not referenced in this function; `args.makeqa` is
        consulted instead.
    sample : table-like or None
        Forwarded to `plan` and `build_cmdargs`.
    input_redshifts : bool
        Forwarded to `build_cmdargs`.
    outdir_data : str
        Forwarded to `plan`.
    templates, templateversion, fphotodir, fphotofile
        Accepted but not referenced in this function; the corresponding
        attributes of `args` are read by `build_cmdargs`.

    Returns
    -------
    None

    """
    import sys
    from desispec.parallel import stdouterr_redirected, weighted_partition
    from fastspecfit.mpi import plan

    if comm:
        rank = comm.rank
        size = comm.size
    else:
        # Serial mode: behave like a single rank with no sub-communicator.
        rank, size = 0, 1
        subcomm = None

    # Every rank calls plan(); only rank 0 times it and decides (below) whether
    # there is any work left.
    if rank == 0:
        t0 = time.time()
    _, all_redrockfiles, all_outfiles, all_ntargets = plan(
        comm=comm, specprod=args.specprod, specprod_dir=specprod_dir,
        coadd_type=args.coadd_type, survey=args.survey, program=args.program,
        healpix=args.healpix, tile=args.tile, night=args.night,
        makeqa=args.makeqa, mp=args.mp, fastphot=fastphot,
        outdir_data=outdir_data, overwrite=args.overwrite, sample=sample)
    if rank == 0:
        log.info(f'Planning took {time.time() - t0:.2f} sec')
        # If no work is left to do, let the other ranks know so they can return
        # politely.
        if len(all_redrockfiles) == 0:
            alldone = True
        else:
            alldone = False
    else:
        alldone = False

    # Rank 0's verdict overrides the placeholder value on every other rank.
    if comm:
        alldone = comm.bcast(alldone, root=0)

    if alldone:
        return

    if comm:
        # Split the MPI.COMM_WORLD communicator into subcommunicators (of size
        # args.mp) so we can MPI-parallelize over objects.
        allranks = np.arange(comm.size)
        if args.purempi:
            # Consecutive blocks of args.mp ranks share a color.
            colors = np.arange(comm.size) // args.mp
            color = rank // args.mp
        else:
            # One color per rank, i.e., sub-communicators of size one.
            colors = np.arange(comm.size)
            color = rank

        subcomm = comm.Split(color=color, key=rank)

        if rank == 0:
            if args.purempi:
                subranks0 = allranks[::args.mp] # rank=0 in each subcomm
                log.info(f'Rank {rank}: dividing filelist into {len(subranks0):,d} sub-communicator(s) ' + \
                         f'(size={comm.size:,d}, mp={args.mp}).')
            else:
                subranks0 = allranks
                log.info(f'Rank {rank}: dividing filelist across {comm.size:,d} ranks.')
        else:
            subranks0 = None

        # All ranks need the list of sub-communicator leaders to know whether
        # they should expect work from rank 0.
        subranks0 = comm.bcast(subranks0, root=0)

        # Send the filelists and number of targets to each subrank0.
        if rank == 0:
            # One group of file indices per sub-communicator leader, weighted
            # by the number of targets.
            groups = weighted_partition(all_ntargets, len(subranks0))
            for irank in range(1, len(subranks0)):
                log.debug(f'Rank {rank} sending work to rank {subranks0[irank]}')
                comm.send(all_redrockfiles[groups[irank]], dest=subranks0[irank], tag=1)
                comm.send(all_outfiles[groups[irank]], dest=subranks0[irank], tag=2)
                comm.send(all_ntargets[groups[irank]], dest=subranks0[irank], tag=3)
            # Rank 0 keeps the first group for itself.
            redrockfiles = all_redrockfiles[groups[rank]]
            outfiles = all_outfiles[groups[rank]]
            ntargets = all_ntargets[groups[rank]]
        else:
            if rank in subranks0:
                log.debug(f'Rank {rank}: received work from rank 0')
                redrockfiles = comm.recv(source=0, tag=1)
                outfiles = comm.recv(source=0, tag=2)
                ntargets = comm.recv(source=0, tag=3)

        # Each subrank0 sends work to the subranks it controls.
        if subcomm.rank == 0:
            # World ranks that share this leader's color.
            subranks = allranks[np.isin(colors, color)]
            # process from smallest to largest
            srt = np.argsort(ntargets)#[::-1]
            redrockfiles = redrockfiles[srt]
            outfiles = outfiles[srt]
            ntargets = ntargets[srt]
            # Every member of the sub-communicator gets the same (sorted) lists.
            for irank in range(1, subcomm.size):
                log.debug(f'Subrank 0 (rank {rank}) sending work to subrank {irank} (rank {subranks[irank]})')
                subcomm.send(redrockfiles, dest=irank, tag=4)
                subcomm.send(outfiles, dest=irank, tag=5)
                subcomm.send(ntargets, dest=irank, tag=6)
        else:
            redrockfiles = subcomm.recv(source=0, tag=4)
            outfiles = subcomm.recv(source=0, tag=5)
            ntargets = subcomm.recv(source=0, tag=6)
    else:
        # Serial mode: this process handles every file.
        redrockfiles = all_redrockfiles
        outfiles = all_outfiles
        ntargets = all_ntargets
    #print(f'Rank={comm.rank}, subrank={subcomm.rank}, redrockfiles={redrockfiles}, ntargets={ntargets}')

    for redrockfile, outfile, ntarget in zip(redrockfiles, outfiles, ntargets):
        # Only the leader of each sub-communicator (or rank 0 when serial) logs.
        if subcomm:
            if subcomm.rank == 0:
                if args.purempi:
                    log.debug(f'Rank {rank} (subrank {subcomm.rank}) started ' + \
                              f'at {time.asctime()}')
                else:
                    log.debug(f'Rank {rank} started at {time.asctime()}')
        elif rank == 0:
            log.debug(f'Rank {rank} started at {time.asctime()}')

        # Pick the entry point to call; all three are bound to the name `fast`.
        if args.makeqa:
            from fastspecfit.qa import fastqa as fast
        else:
            if fastphot:
                from fastspecfit.fastspecfit import fastphot as fast
            else:
                from fastspecfit.fastspecfit import fastspec as fast

        cmd, cmdargs, logfile = build_cmdargs(args, redrockfile, outfile, sample=sample,
                                              fastphot=fastphot, input_redshifts=input_redshifts)

        if subcomm:
            if subcomm.rank == 0:
                if args.purempi:
                    log.info(f'Rank {rank} (nsubrank={subcomm.size}): ' + \
                             f'ntargets={ntarget}: {cmd} {cmdargs}')
                else:
                    log.info(f'Rank {rank} ntargets={ntarget}: {cmd} {cmdargs}')
        elif rank == 0:
            log.info(f'Rank {rank}: ntargets={ntarget}: {cmd} {cmdargs}')

        # With --dry-run the command is logged above but never executed.
        if args.dry_run:
            continue

        try:
            # The leader starts the timer and makes sure the log directory exists.
            if subcomm:
                if subcomm.rank == 0:
                    t1 = time.time()
                    outdir = os.path.dirname(logfile)
                    if not os.path.isdir(outdir):
                        os.makedirs(outdir, exist_ok=True)
            elif rank == 0:
                t1 = time.time()
                outdir = os.path.dirname(logfile)
                if not os.path.isdir(outdir):
                    os.makedirs(outdir, exist_ok=True)

            # The sub-communicator is only handed to `fast` with --purempi;
            # otherwise it is called with comm=None.
            if args.nolog:
                if args.purempi:
                    err = fast(args=cmdargs.split(), comm=subcomm)
                else:
                    err = fast(args=cmdargs.split(), comm=None)
            else:
                with stdouterr_redirected(to=logfile, overwrite=args.overwrite, comm=subcomm):
                    if args.purempi:
                        err = fast(args=cmdargs.split(), comm=subcomm)
                    else:
                        err = fast(args=cmdargs.split(), comm=None)

            # A non-zero return value only counts as a failure if the output
            # file is also missing; the IOError is caught by the handler below.
            if subcomm:
                if subcomm.rank == 0:
                    log.info(f'Rank {rank} done in {time.time() - t1:.2f} sec')
                    if err != 0:
                        if not os.path.exists(outfile):
                            log.warning(f'Rank {rank} missing {outfile}')
                            raise IOError
            elif rank == 0:
                log.info(f'Rank {rank} done in {time.time() - t1:.2f} sec')
                if err != 0:
                    if not os.path.exists(outfile):
                        log.warning(f'Rank {rank} missing {outfile}')
                        raise IOError

        except:
            # Best effort: report the failure and move on to the next file.
            log.warning(f'Rank {rank} raised an exception')
            import traceback
            traceback.print_exc()

    if subcomm:
        if subcomm.rank == 0:
            log.debug(f'Rank {rank} is done')
    elif rank == 0:
        log.debug(f'Rank {rank} is done')

    # Wait for every rank to finish before the final report.
    if comm:
        comm.barrier()

    # NOTE(review): under MPI `outfiles` holds only the files assigned to rank
    # 0's group, not `all_outfiles` -- confirm that checking just this subset
    # is intended.
    if rank == 0 and not args.dry_run:
        for outfile in outfiles:
            if not os.path.exists(outfile):
                log.warning(f'Missing {outfile}')

        log.info(f'All done at {time.asctime()}')


def main():
    """Command-line entry point: parse options, then merge, plan, or fit.

    The options are read from `sys.argv`. Depending on the flags this either

    * merges existing catalogs (--merge / --mergeall*) and returns;
    * only runs the planning step (--plan); or
    * runs fastspec, fastphot, or fastqa over all the planned files via
      `run_fastspecfit`, optionally under cProfile (--profile).

    MPI is used unless --nompi is given or mpi4py cannot be imported. When a
    --samplefile is given it is read on rank 0 and shared with the other
    ranks.

    Raises
    ------
    NotImplementedError
        If --samplefile is combined with a --coadd-type other than 'healpix'.
    ValueError
        If the sample file cannot be read with the required columns.

    """
    import argparse
    from fastspecfit.mpi import plan

    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--coadd-type', type=str, default='healpix', choices=['healpix', 'cumulative', 'pernight', 'perexp'],
                        help='Specify which type of spectra/zbest files to process.')
    parser.add_argument('--specprod', type=str, default='loa', help='Spectroscopic production to process.')

    parser.add_argument('--healpix', type=str, default=None, help='Comma-separated list of healpixels to process.')
    parser.add_argument('--survey', type=str, default='main,special,cmx,sv1,sv2,sv3', help='Survey to process.')
    parser.add_argument('--program', type=str, default='bright,dark,other,backup', help='Program to process.') # backup not supported
    parser.add_argument('--tile', default=None, type=str, nargs='*', help='Tile(s) to process.')
    parser.add_argument('--night', default=None, type=str, nargs='*', help='Night(s) to process (ignored if coadd-type is cumulative).')

    parser.add_argument('--samplefile', default=None, type=str, help='Full path to sample (FITS) file with {survey,program,healpix,targetid}.')
    parser.add_argument('--input-redshifts', action='store_true', help='Only used with --samplefile; if set, fit with redshift "Z" values.')
    parser.add_argument('--input-seeds', type=str, default=None, help='Comma-separated list of input random-number seeds corresponding to the (required) --targetids input.')
    parser.add_argument('--nmonte', type=int, default=10, help='Number of Monte Carlo realizations.')
    parser.add_argument('--seed', type=int, default=1, help='Random seed for Monte Carlo reproducibility; ignored if --input-seeds is passed.')

    parser.add_argument('--mp', type=int, default=1, help='Number of multiprocessing processes per MPI rank or node.')
    parser.add_argument('-n', '--ntargets', type=int, help='Number of targets to process in each file.')
    parser.add_argument('--firsttarget', type=int, default=0, help='Index of first object to to process in each file, zero-indexed.')
    parser.add_argument('--targetids', type=str, default=None, help='Comma-separated list of TARGETIDs to process.')

    parser.add_argument('--fastphot', action='store_true', help='Fit the broadband photometry.')

    parser.add_argument('--templateversion', type=str, default=None, help='Template version number.')
    parser.add_argument('--templates', type=str, default=None, help='Optional full path and filename to the templates.')
    parser.add_argument('--fphotodir', type=str, default=None, help='Top-level location of the source photometry.')
    parser.add_argument('--fphotofile', type=str, default=None, help='Photometric information file.')

    parser.add_argument('--merge', action='store_true', help='Merge all individual catalogs (for a given survey and program) into one large file.')
    parser.add_argument('--mergedir', type=str, help='Output directory for merged catalogs.')
    parser.add_argument('--merge-suffix', type=str, help='Filename suffix for merged catalogs.')
    parser.add_argument('--mergeall', action='store_true', help='Merge all the individual merged catalogs into a single merged catalog.')
    parser.add_argument('--mergeall-main', action='store_true', help='Merge all the main catalogs.')
    parser.add_argument('--mergeall-sv', action='store_true', help='Merge all the SV catalogs.')
    parser.add_argument('--mergeall-special', action='store_true', help='Merge all the special catalogs.')
    parser.add_argument('--makeqa', action='store_true', help='Build QA in parallel.')

    parser.add_argument('--ignore-quasarnet', default=False, action='store_true', help='Do not use QuasarNet to improve QSO redshifts.')
    parser.add_argument('--ignore-photometry', default=False, action='store_true', help='Ignore the broadband photometry during model fitting.')
    parser.add_argument('--no-smooth-continuum', default=False, action='store_true', help='Do not fit the smooth continuum.')

    parser.add_argument('--verbose', action='store_true', help='More verbose output.')
    parser.add_argument('--overwrite', action='store_true', help='Overwrite any existing output files.')
    parser.add_argument('--plan', action='store_true', help='Plan how many nodes to use and how to distribute the targets.')
    parser.add_argument('--profile', action='store_true', help='Write out profiling / timing files..')
    parser.add_argument('--nompi', action='store_true', help='Do not use MPI parallelism.')
    parser.add_argument('--purempi', action='store_true', help='Use only MPI parallelism; no multiprocessing.')
    parser.add_argument('--nolog', action='store_true', help='Do not write to the log file.')
    parser.add_argument('--dry-run', action='store_true', help='Generate but do not run commands.')

    parser.add_argument('--outdir-data', default='$PSCRATCH/fastspecfit/data', type=str, help='Base output data directory.')

    args = parser.parse_args()

    specprod_dir = None
    outdir_data = os.path.expanduser(os.path.expandvars(args.outdir_data))

    if args.nompi:
        comm = None
    else:
        try:
            from mpi4py import MPI
            # needed when profiling; no effect otherwise
            # https://docs.linaroforge.com/24.0.6/html/forge/map/python_profiling/profile_python_script.html
            #MPI.Init_thread(MPI.THREAD_SINGLE)
            comm = MPI.COMM_WORLD
        except ImportError:
            comm = None

    if comm:
        rank = comm.rank
        # The warning fires only when size < mp, so size == mp is acceptable.
        if rank == 0 and args.purempi and comm.size > 1 and args.mp > 1 and comm.size < args.mp:
            log.warning(f'Number of MPI tasks {comm.size} should be >={args.mp} for MPI parallelism.')
    else:
        rank = 0
        # https://docs.nersc.gov/development/languages/python/parallel-python/#use-the-spawn-start-method
        if args.mp > 1 and 'NERSC_HOST' in os.environ:
            import multiprocessing
            multiprocessing.set_start_method('spawn')

    # If an input samplefile is provided, read and broadcast it.
    if args.samplefile is None:
        sample = None
    else:
        if args.coadd_type != 'healpix':
            errmsg = 'Input --samplefile is only currently compatible with --coadd-type="healpix"'
            log.critical(errmsg)
            raise NotImplementedError(errmsg)

        # Rank 0 does the I/O and records the outcome; the outcome is then
        # shared so that every rank takes the same exit path and none is left
        # waiting in the broadcast.
        sample, errmsg, read_error, missing = None, None, None, False
        if rank == 0:
            import fitsio
            if not os.path.isfile(args.samplefile):
                log.warning(f'{args.samplefile} does not exist.')
                missing = True
            else:
                # 'Z' is only needed (and only required) with --input-redshifts.
                columns = ['SURVEY', 'PROGRAM', 'HEALPIX', 'TARGETID']
                if args.input_redshifts:
                    columns.append('Z')
                try:
                    sample = Table(fitsio.read(args.samplefile, columns=columns))
                    log.info(f'Read {len(sample)} rows from {args.samplefile}')
                except Exception as err:
                    read_error = err
                    if args.input_redshifts:
                        errmsg = f'Sample file {args.samplefile} with --input-redshifts missing required columns ' + \
                            '{SURVEY,PROGRAM,HEALPIX,TARGETID,Z}'
                    else:
                        errmsg = f'Sample file {args.samplefile} missing required columns ' + \
                            '{SURVEY,PROGRAM,HEALPIX,TARGETID}'
                    log.critical(errmsg)

        # Without MPI (comm is None) rank 0's values are used directly.
        if comm:
            sample, errmsg, missing = comm.bcast((sample, errmsg, missing), root=0)

        if missing:
            return
        if errmsg is not None:
            # read_error is only set on rank 0; elsewhere this is `from None`.
            raise ValueError(errmsg) from read_error

    # Parse some of the inputs.
    if args.samplefile is None and args.coadd_type == 'healpix':
        args.survey = args.survey.split(',')
        args.program = args.program.split(',')
        if args.healpix is not None:
            args.healpix = args.healpix.split(',')

    # Any of the specific --mergeall-* flags implies --mergeall.
    if args.mergeall_main or args.mergeall_sv or args.mergeall_special:
        args.mergeall = True

    if args.merge or args.mergeall:
        from fastspecfit.mpi import merge_fastspecfit

        # convenience code to make the super-merge catalogs, e.g., fastspec-iron-{main,special,sv}.fits
        if args.fastphot:
            fastprefix = 'fastphot'
        else:
            fastprefix = 'fastspec'

        if args.mergeall_main or args.mergeall_sv or args.mergeall_special:
            from glob import glob
            if args.mergedir is None:
                mergedir = os.path.join(outdir_data, args.specprod, 'catalogs')
            else:
                mergedir = args.mergedir

            if args.mergeall_main:
                args.merge_suffix = f'{args.specprod}-main'
                fastfiles_to_merge = sorted(glob(os.path.join(mergedir, f'{fastprefix}-{args.specprod}-main-*.fits')))
            elif args.mergeall_special:
                args.merge_suffix = f'{args.specprod}-special'
                fastfiles_to_merge = sorted(glob(os.path.join(mergedir, f'{fastprefix}-{args.specprod}-special-*.fits')))
            elif args.mergeall_sv:
                # Everything for this specprod except the main/special
                # catalogs and the previously merged main/special/sv files.
                args.merge_suffix = f'{args.specprod}-sv'
                fastfiles_to_merge = sorted(list(set(glob(os.path.join(mergedir, f'{fastprefix}-{args.specprod}-*.fits'))) -
                                                 set(glob(os.path.join(mergedir, f'{fastprefix}-{args.specprod}-main.fits'))) -
                                                 set(glob(os.path.join(mergedir, f'{fastprefix}-{args.specprod}-special.fits'))) -
                                                 set(glob(os.path.join(mergedir, f'{fastprefix}-{args.specprod}-special-*.fits'))) -
                                                 set(glob(os.path.join(mergedir, f'{fastprefix}-{args.specprod}-main-*.fits'))) -
                                                 set(glob(os.path.join(mergedir, f'{fastprefix}-{args.specprod}-sv.fits')))))
            else:
                fastfiles_to_merge = None
        else:
            fastfiles_to_merge = None

        if args.samplefile is not None:
            merge_fastspecfit(specprod=args.specprod, specprod_dir=specprod_dir, coadd_type='healpix',
                              sample=sample, merge_suffix=args.merge_suffix,
                              outdir_data=outdir_data, fastfiles_to_merge=fastfiles_to_merge,
                              outsuffix=args.merge_suffix, mergedir=args.mergedir, overwrite=args.overwrite,
                              fastphot=args.fastphot, supermerge=args.mergeall, mp=args.mp)
        else:
            merge_fastspecfit(specprod=args.specprod, specprod_dir=specprod_dir, coadd_type=args.coadd_type,
                              survey=args.survey, program=args.program, healpix=args.healpix,
                              tile=args.tile, night=args.night, outdir_data=outdir_data,
                              fastfiles_to_merge=fastfiles_to_merge, outsuffix=args.merge_suffix,
                              mergedir=args.mergedir, overwrite=args.overwrite,
                              fastphot=args.fastphot, supermerge=args.mergeall, mp=args.mp)
        return


    if args.plan:
        plan(comm=comm, specprod=args.specprod, specprod_dir=specprod_dir,
             coadd_type=args.coadd_type, survey=args.survey, program=args.program,
             healpix=args.healpix, tile=args.tile, night=args.night,
             makeqa=args.makeqa, mp=args.mp, fastphot=args.fastphot,
             outdir_data=outdir_data, overwrite=args.overwrite,
             sample=sample)
    else:
        if args.profile:
            import cProfile
            import pstats

            profiler = cProfile.Profile()
            profiler.enable()

        run_fastspecfit(args, comm=comm, fastphot=args.fastphot, specprod_dir=specprod_dir,
                        makeqa=args.makeqa, outdir_data=outdir_data, sample=sample,
                        input_redshifts=args.input_redshifts, templates=args.templates,
                        templateversion=args.templateversion, fphotodir=args.fphotodir,
                        fphotofile=args.fphotofile)

        if args.profile:
            profiler.disable()

            # One binary and one human-readable profile per rank.
            outfile = os.path.join(outdir_data, f'profile_rank{rank}.prof')
            log.info(f'Writing {outfile}')
            profiler.dump_stats(outfile)
            with open(os.path.join(outdir_data, f'profile_rank{rank}.txt'), 'w') as F:
                stats = pstats.Stats(profiler, stream=F)
                stats.strip_dirs()
                stats.sort_stats('cumtime')
                stats.print_stats()
                stats.print_callees()

    # `MPI` is bound whenever `comm` is set (both come from the same import).
    if comm:
        MPI.Finalize()


# Run the command-line entry point only when this file is executed as a script.
if __name__ == '__main__':
    main()
