#!python

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import os
import sys
import math
import numpy as np
import argparse
import pyspextools
from pyspextools.io.arf import Arf
from pyspextools.io.rmf import Rmf
from pyspextools.data.response import gaussrsp

import pyspextools.messages as message

from builtins import int

from future import standard_library

standard_library.install_aliases()

def main():
    """Generate an OGIP response file with a Gaussian redistribution function.

    This program can generate response files for new detectors with an arbitrary
    (Gaussian) redistribution function. The size of the response file and the
    respective energy ranges can be varied within the limits of the supplied
    effective area (ARF) file. The output format for this tool is OGIP.

    All input is taken from the command line (see genrsp_arguments). The program
    exits with a non-zero status when the arguments describe an empty or invalid
    energy grid, when the matrix cannot be allocated, or when a negative response
    value is encountered.
    """

    # Obtain command line arguments
    parser = genrsp_arguments()
    args = parser.parse_args()

    # Print header
    message.print_header(os.path.basename(__file__))

    # Set color in the terminal
    message.set_color(args.color)

    # Reject values that would give a zero or negative energy step further down
    # (the step is proportional to resolution / sampling).
    if args.resolution <= 0.0 or args.sampling <= 0:
        message.error('The --resolution and --sampling values must be larger than zero.')
        sys.exit(1)

    low = args.range[0]   # Lowest energy
    high = args.range[1]  # Highest energy
    if not high > low:
        message.error('The upper limit of --range must be larger than the lower limit.')
        sys.exit(1)

    # Read the input ARF file
    # NOTE(review): the returned status is only reported here, not acted upon --
    # confirm whether a failed read should abort the program.
    message.proc_start('Reading ARF file')
    arf_in = Arf()
    stat = arf_in.read(args.arffile)
    message.proc_end(stat)

    # Create input array with bin centers
    x = (arf_in.LowEnergy + arf_in.HighEnergy) / 2.0

    # Define the new grid
    step = 1E-3 * args.resolution / args.sampling  # Energy stepsize (in keV)
    nbins = int((high - low) / step)
    if nbins < 1:
        message.error('The requested energy range is smaller than a single energy bin.')
        sys.exit(1)

    # Create a new RSP output object
    rsp_out = Rmf()

    rsp_out.NumberEnergyBins = nbins
    rsp_out.EnergyUnits = arf_in.EnergyUnits

    if args.noarea:
        rsp_out.AreaIncluded = False
        rsp_out.RMFUnits = ''
    else:
        rsp_out.AreaIncluded = True
        rsp_out.RMFUnits = arf_in.ARFUnits

    # Calculate new arrays
    bin_index = np.arange(nbins, dtype=float)
    rsp_out.LowEnergy = low + step * bin_index
    rsp_out.HighEnergy = low + step * (bin_index + 1.0)

    # Linear interpolation of the effective area at the new bin centers
    EffArea = np.interp((rsp_out.LowEnergy + rsp_out.HighEnergy) / 2.0, x, arf_in.EffArea)

    # Assume same binning for energy channels (Not optimal, but ok...)
    rsp_out.NumberChannels = rsp_out.NumberEnergyBins

    rsp_out.Channel = np.arange(rsp_out.NumberChannels) + 1
    rsp_out.FirstChannel = rsp_out.Channel[0]
    rsp_out.ChannelLowEnergy = rsp_out.LowEnergy
    rsp_out.ChannelHighEnergy = rsp_out.HighEnergy

    # We choose only one response group per energy (simple Gaussian response)
    rsp_out.NumberGroups = np.ones(rsp_out.NumberChannels)

    # Set response thresholds
    rsp_out.ResponseThreshold = 1.E-7

    # Determine maximum width in channels for group
    # We assume that 10 times the FWHM at 1 keV is enough
    # (int() because math.ceil returns a float on Python 2)
    nbin_group = int(math.ceil(10 * args.resolution * 1E-3 / step))

    # A group can never be wider than the channel grid itself; without this cap
    # the channel index below would become negative and wrap around.
    nbin_group = min(nbin_group, rsp_out.NumberChannels)

    try:
        rsp_out.Matrix = np.zeros(nbin_group * rsp_out.NumberEnergyBins, dtype=float)
    except MemoryError:
        message.error('Not enough memory to create matrix...')
        # Nothing sensible can be computed without the matrix array
        sys.exit(1)

    rsp_out.NumberChannelsGroup = np.zeros(rsp_out.NumberEnergyBins, dtype=int)
    rsp_out.FirstChannelGroup = np.zeros(rsp_out.NumberEnergyBins, dtype=int)

    # Generate response matrix
    print("Number of energy bins: {0}".format(rsp_out.NumberEnergyBins))
    print("Number of channels per group: {0}".format(nbin_group))

    message.proc_start('Calculate response matrix')

    # Channel centers; identical to the energy bin centers because both grids are shared
    centers = (rsp_out.ChannelLowEnergy + rsp_out.ChannelHighEnergy) / 2.0

    half_width = int(math.ceil(nbin_group / 2))

    # Channels are numbered from 1, so a group starting at channel f covers
    # f .. f + nbin_group - 1. The largest start that still fits in the grid is:
    last_start = rsp_out.NumberChannels - nbin_group + 1

    r = 0
    for i in range(rsp_out.NumberEnergyBins):
        # The number of channels per group is nbin_group
        rsp_out.NumberChannelsGroup[i] = nbin_group

        # Center the group on bin i and clamp it to the channel grid
        first = min(max(1 + i - half_width, 1), last_start)
        rsp_out.FirstChannelGroup[i] = first

        # Fill the Matrix
        mu = centers[i]
        for j in range(nbin_group):
            e = centers[first - 1 + j]
            resp = gaussrsp(e, mu, args.resolution, args.resgradient)
            if resp < 0.:
                message.error('Negative response value detected. Quitting program...')
                sys.exit(1)

            # Threshold is applied to the redistribution value, before the area is folded in
            value = resp * step
            if value < rsp_out.ResponseThreshold:
                value = 0.
            if not args.noarea:
                value = value * EffArea[i]
            rsp_out.Matrix[r] = value
            r = r + 1

    rsp_out.NumberTotalGroups = rsp_out.NumberEnergyBins
    rsp_out.NumberTotalElements = r
    message.proc_end(0)

    if args.noarea:
        if os.path.splitext(args.rspfile)[1] == '.rsp':
            message.warning("You are creating a .rsp file without area. Rename output file extension to .rmf.")

    # Check the created matrix
    message.proc_start('Check the created RSP/RMF matrix')
    stat = rsp_out.check()
    message.proc_end(stat)

    # Write the new matrix to file
    message.proc_start('Write RSP/RMF to file')
    stat = rsp_out.write(args.rspfile, overwrite=args.overwrite)
    message.proc_end(stat)


def parserange(string):
    """Split a 'low:high' range string into an array of two floats.

    Used as the argparse ``type`` converter for the --range option.

    :param string: Range specification with the two limits separated by a
        colon, for example '0.1:10'.
    :return: numpy array holding the lower and upper limit as floats.

    If the string has no colon or a limit is not a number, an error message is
    printed and the program exits with status 1.
    """
    values = string.split(':')
    try:
        limits = np.array([float(values[0]), float(values[1])])
    except (ValueError, IndexError):
        # IndexError: no colon in the input, so there is no second value
        message.error('Invalid range format. Please input a range separated by a colon. For example: --range 0.1:10')
        sys.exit(1)
    return limits


def genrsp_arguments():
    """Build the command line parser for this tool.

    The options are listed in a table and registered in that order, so the
    generated help text keeps the same layout.

    :return: argparse.ArgumentParser with all options of the program defined.
    """
    option_table = (
        ('--arffile', dict(help='Input Effective area file (.arf, required).', type=str, required=True)),
        ('--resolution', dict(help='Spectral resolution at 1 keV (FWHM, eV, required).', type=float,
                              required=True)),
        ('--resgradient', dict(help='Energy dependence of spectral resolution (in eV per keV, see documentation).',
                               type=float, default=0.0)),
        ('--range', dict(help='Energy range of the response files (in keV), for example --range 0.1:10.',
                         type=parserange, default='0.1:10')),
        ('--sampling', dict(help='The number of channels per resolution element (FWHM) at 1 keV.',
                            type=int, default=5)),
        ('--rspfile', dict(help='Output filename for the OGIP RSP file (.rsp required).', type=str,
                           required=True)),
        ('--no-area', dict(help="Do NOT include effective area in response file (create .rmf).", dest="noarea",
                           action="store_true", default=False)),
        ('--overwrite', dict(help="Overwrite existing rsp files with same name.", action="store_true",
                             default=False)),
        ('--no-color', dict(help="Suppress color output.", dest="color", action="store_false", default=True)),
        ('--version', dict(action='version', version=message.version)),
    )

    cli = argparse.ArgumentParser(description=message.docs)
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)

    return cli


# Run the program only when this file is executed as a script, not when it is imported
if __name__ == "__main__":
    main()
