#!/usr/bin/env python
# Licensed under a 3-clause BSD style license - see LICENSE.rst

from __future__ import (absolute_import, unicode_literals, division,
                        print_function)

from maltpynt.mp_lcurve import (mp_lcurve_from_events, mp_scrunch_lightcurves,
                                mp_join_lightcurves, mp_lcurve_from_fits,
                                mp_lcurve_from_txt)
import argparse
import numpy as np
import logging
from multiprocessing import Pool
import functools

# Command-line interface for MPlcurve. The parser is built at module level
# so it can be inspected or reused (e.g. for documentation) without running
# the script.
description = ('Create lightcurves starting from event files. It is '
               'possible to specify energy or channel filtering options')
parser = argparse.ArgumentParser(description=description)

parser.add_argument("files", help="List of files", nargs='+')

parser.add_argument("-b", "--bintime", type=float, default=1,
                    help="Bin time; if negative, negative power of 2")
parser.add_argument("--safe-interval", nargs=2, type=float,
                    default=[0, 0],
                    help="Interval at start and stop of GTIs used" +
                    " for filtering")
# [-1, -1] is the "no filtering" default passed on to mp_lcurve_from_events
# (presumably a sentinel; confirm against maltpynt.mp_lcurve).
parser.add_argument("--pi-interval", type=int, default=[-1, -1],
                    nargs=2,
                    help="PI interval used for filtering")
parser.add_argument('-e', "--e-interval", type=float, default=[-1, -1],
                    nargs=2,
                    help="Energy interval used for filtering")
parser.add_argument("-s", "--scrunch",
                    help="Create scrunched light curve (single channel)",
                    default=False,
                    action="store_true")
parser.add_argument("-j", "--join",
                    help="Create joint light curve (multiple channels)",
                    default=False,
                    action="store_true")
parser.add_argument("-g", "--gti-split",
                    help="Split light curve by GTI",
                    default=False,
                    action="store_true")
parser.add_argument("--minlen",
                    help="Minimum length of acceptable GTIs (default:4)",
                    default=4, type=float)
parser.add_argument("--ignore-gtis",
                    help="Ignore GTIs",
                    default=False,
                    action="store_true")
parser.add_argument("-d", "--outdir", type=str, default=None,
                    help='Output directory')
parser.add_argument("--loglevel",
                    help=("use given logging level (one between INFO, "
                          "WARNING, ERROR, CRITICAL, DEBUG; "
                          "default:WARNING)"),
                    default='WARNING',
                    type=str)
parser.add_argument("--nproc",
                    help=("Number of processors to use"),
                    default=1,
                    type=int)
parser.add_argument("--debug", help="use DEBUG logging level",
                    default=False, action='store_true')
parser.add_argument("--noclobber", help="Do not overwrite existing files",
                    default=False, action='store_true')

# The two input-format flags select different readers, so they cannot be
# combined. Without the group, passing both silently used the FITS reader.
_input_format = parser.add_mutually_exclusive_group()
_input_format.add_argument("--fits-input",
                           help="Input files are light curves in FITS format",
                           default=False, action='store_true')
_input_format.add_argument("--txt-input",
                           help="Input files are light curves in txt format",
                           default=False, action='store_true')

def main(argv=None):
    """Create light curves from the files given on the command line.

    Parse ``argv`` with the module-level ``parser`` and configure logging to
    ``MPlcurve.log``. Then run the selected ``mp_lcurve*`` function on every
    input file in a ``multiprocessing.Pool`` of ``--nproc`` workers.
    Optionally scrunch and/or join the results.

    Parameters
    ----------
    argv : list of str, optional
        Command-line arguments, excluding the program name. If None,
        argparse reads ``sys.argv[1:]``.

    Returns
    -------
    outfiles : list
        The value returned by the selected ``mp_lcurve*`` function for each
        input file, in the same order as the input files.
    """
    args = parser.parse_args(argv)

    if args.debug:
        args.loglevel = 'DEBUG'

    # getattr alone accepts typos (falling back to None) and non-level
    # attributes such as "basicConfig"; only integer levels are valid.
    numeric_level = getattr(logging, args.loglevel.upper(), None)
    if not isinstance(numeric_level, int):
        parser.error('Invalid log level: %s' % args.loglevel)
    logging.basicConfig(filename='MPlcurve.log', level=numeric_level,
                        filemode='w')

    # ------ Use functools.partial to wrap mp_lcurve* with relevant keywords---
    if args.fits_input:
        wrap_fun = functools.partial(mp_lcurve_from_fits,
                                     noclobber=args.noclobber)
    elif args.txt_input:
        wrap_fun = functools.partial(mp_lcurve_from_txt,
                                     noclobber=args.noclobber)
    else:
        wrap_fun = functools.partial(
            mp_lcurve_from_events,
            safe_interval=args.safe_interval,
            pi_interval=np.array(args.pi_interval),
            e_interval=np.array(args.e_interval),
            min_length=args.minlen,
            gti_split=args.gti_split,
            ignore_gtis=args.ignore_gtis,
            bintime=args.bintime, outdir=args.outdir,
            noclobber=args.noclobber)
    # -------------------------------------------------------------------------

    pool = Pool(processes=args.nproc)
    try:
        # map (rather than imap_unordered) returns results in input order,
        # so the scrunch/join steps below see a deterministic ordering
        # instead of whichever order the workers happened to finish in.
        outfiles = pool.map(wrap_fun, args.files)
    finally:
        # Always release the workers. join() waits for them to exit
        # instead of leaving them running behind the main process.
        pool.close()
        pool.join()

    logging.debug(outfiles)

    if args.scrunch:
        mp_scrunch_lightcurves(outfiles)

    if args.join:
        mp_join_lightcurves(outfiles)

    return outfiles


if __name__ == '__main__':
    main()
