#!/usr/bin/env python
# Licensed under a 3-clause BSD style license - see LICENSE.rst

from __future__ import (absolute_import, unicode_literals, division,
                        print_function)

from maltpynt.mp_io import mp_get_file_extension, MP_FILE_EXTENSION
from maltpynt.mp_calibrate import mp_calibrate
import argparse
import logging
from multiprocessing import Pool


def mp_calib_wrap(args):
    """Run mp_calibrate on one bundled ``(file, outname, rmf)`` job.

    ``Pool.imap_unordered`` passes exactly one object to each task, so the
    three calibration arguments travel together as a sequence and are
    expanded here.  The function lives at module level so that
    multiprocessing can pickle it for the worker processes.

    Parameters
    ----------
    args : sequence of length 3
        Input event file name, output file name, and rmf file (or None).

    Returns
    -------
    Whatever ``mp_calibrate`` returns for that job.
    """
    fname, out_name, rmf_file = args
    return mp_calibrate(fname, out_name, rmf_file)

description = ('Calibrate clean event files by associating the correct '
               'energy to each PI channel. Uses either a specified rmf file '
               'or (for NuSTAR only) an rmf file from the CALDB')


def _build_parser(desc):
    """Return the argparse parser for the calibration command line.

    Options are declared as ``(flags, keyword-arguments)`` pairs and added
    in list order, which is also the order they appear in ``--help``.
    """
    cli = argparse.ArgumentParser(description=desc)
    option_specs = [
        (("files",),
         dict(help="List of files", nargs='+')),
        (("-r", "--rmf"),
         dict(help="rmf file used for calibration",
              default=None, type=str)),
        (("-o", "--overwrite"),
         dict(help="Overwrite; default: no",
              default=False, action="store_true")),
        (("--loglevel",),
         dict(help=("use given logging level (one between INFO, "
                    "WARNING, ERROR, CRITICAL, DEBUG; default:WARNING)"),
              default='WARNING', type=str)),
        (("--debug",),
         dict(help="use DEBUG logging level",
              default=False, action='store_true')),
        (("--nproc",),
         dict(help="Number of processors to use",
              default=1, type=int)),
    ]
    for flags, options in option_specs:
        cli.add_argument(*flags, **options)
    return cli


parser = _build_parser(description)

def mp_calib_outname(fname, ext, calib_ext):
    """Return the output file name for the calibrated version of ``fname``.

    Only the LAST occurrence of ``ext`` in ``fname`` is replaced by
    ``'_calib' + calib_ext``, so directory names that happen to contain the
    extension string are left untouched.  If ``ext`` is empty or does not
    occur in ``fname``, the suffix is appended instead; the output name
    therefore never equals the input name, and the input file is not
    overwritten by accident.

    Parameters
    ----------
    fname : str
        Input file name.
    ext : str
        Extension of ``fname`` to be replaced.
    calib_ext : str
        Extension to use for the calibrated output file.

    Returns
    -------
    str
        The output file name.
    """
    suffix = '_calib' + calib_ext
    if ext:
        head, sep, tail = fname.rpartition(ext)
        if sep:
            return head + suffix + tail
    return fname + suffix


if __name__ == '__main__':
    args = parser.parse_args()

    if args.debug:
        args.loglevel = 'DEBUG'

    # Translate the level name into logging's numeric constant.  An unknown
    # name gives None (silently ignored by basicConfig) or a non-level
    # attribute, so reject anything that is not an int up front.
    numeric_level = getattr(logging, args.loglevel.upper(), None)
    if not isinstance(numeric_level, int):
        parser.error('Invalid log level: %s' % args.loglevel)
    # Pool() itself would raise a less helpful ValueError for this.
    if args.nproc < 1:
        parser.error('--nproc must be at least 1')

    logging.basicConfig(filename='MPcalibrate.log', level=numeric_level,
                        filemode='w')

    # One [input, output, rmf] job per file; mp_calib_wrap unpacks it.
    funcargs = []
    for f in args.files:
        if args.overwrite:
            outname = f
        else:
            outname = mp_calib_outname(f, mp_get_file_extension(f),
                                       MP_FILE_EXTENSION)
        funcargs.append([f, outname, args.rmf])

    pool = Pool(processes=args.nproc)
    try:
        # Drain the iterator so every job runs and any exception raised in a
        # worker is re-raised here in the parent.
        for _ in pool.imap_unordered(mp_calib_wrap, funcargs):
            pass
        pool.close()
    except BaseException:
        # Error or Ctrl-C: stop the remaining workers instead of waiting.
        pool.terminate()
        raise
    finally:
        pool.join()
