#!/usr/bin/env python

"""

    Pipeline for reducing SALT RSS MOS data, using the stuff that comes in the product/ dir.

    Copyright 2014-2020 Matt Hilton (matt.hilton@mykolab.com)
    
    This file is part of RSSMOSPipeline.

    RSSMOSPipeline is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    RSSMOSPipeline is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with RSSMOSPipeline.  If not, see <http://www.gnu.org/licenses/>.
    
"""

import os
import sys
import glob
import time
import datetime
import argparse
import numpy as np
import logging
from RSSMOSPipeline import RSSMOSTools
#import pickle
import IPython
#plt.matplotlib.interactive(True)
    
#-------------------------------------------------------------------------------------------------------------
def makeArgParser():
    """Build the command-line parser for rss_mos_reducer.

    Numeric options are declared with ``type=float`` so that an invalid value
    is reported by argparse as a usage error rather than surfacing later as an
    uncaught ValueError.

    Returns:
        The configured argparse.ArgumentParser.

    """
    parser = argparse.ArgumentParser("rss_mos_reducer")
    parser.add_argument("rawDir", help="The directory that holds the partially-processed data. Usually called product.")
    parser.add_argument("reducedDir", help="The directory where the reduced data will be written, along with any diagnostic data. This directory will be created if it doesn't already exist.")
    parser.add_argument("maskName", help="Use 'all' to reduce all data found under rawDir/, or 'list' to list all masks (by object) found under rawDir/. The maskName is made from the keyword combination OBJECT_MASKID found in the .fits headers.")
    parser.add_argument("-t", "--threshold", dest="threshold", type=float, default=0.1, help="Threshold used for the MOS slit finding algorithm - check that all slits are being found using e.g. masterflat_0.fits and the .reg files produced by the pipeline. Values in the range 0.1-0.4 work best (default=0.1; increase this value if some slits are getting divided; decrease if slits missing). This option only applies to MOS masks.")
    parser.add_argument("-T", "--longslit-threshold", dest="longslitThreshold", type=float, default=2.0, help="Threshold used for the longslit pseudo-slit finding algorithm, in sigma (i.e., for detecting object traces).")
    parser.add_argument("-i", "--iterative-extraction", dest="iterativeMethod", action='store_true', help="Use the iterative spectral extraction method.")
    parser.add_argument("-f", "--skysub-fraction", dest="subFrac", type=float, default=0.8, help="Fraction of the sky background to be subtracted in each iteration of the iterative spectral extraction algorithm (default=0.8; increase this value for faster convergence). This only has an effect if the -i switch is also used.")
    parser.add_argument("-e", "--exclude-masks", dest="excludeMasks", default="", help="Names of masks to exclude (if using maskName = 'all'). Separate mask names with , but no spaces (e.g., -e M1,M2). Useful for avoiding inclusion of e.g., standard stars.")
    parser.add_argument("-s", "--slits", dest="extensionsList", default="all", help="Reduce the data corresponding to the given slit names only. Separate slit names with , but no spaces (e.g., -s SLIT9,SLIT15).")
    parser.add_argument("-S", "--skip-done", dest="skipDone", default=False, action='store_true', help="Skip previously processed masks, for which output already exists.")
    return parser


def selectMasksToReduce(infoDict, excludeMasks, rawDir, logger):
    """Return the names of the masks that can be reduced.

    A mask is selected when it is not listed in excludeMasks and the first
    dict-valued entry with a non-empty 'modelExists' sequence has every element
    of that sequence true. Masks that fail this test are reported through
    logger, together with a hint about creating a reference model when the
    examined entry holds an 'ARC' key.

    Args:
        infoDict: Dictionary keyed by mask name; each value is a dictionary
            mixing plain values with dict-valued entries that carry
            'modelExists' (and possibly 'ARC').
        excludeMasks: Collection of mask names to leave out.
        rawDir: Raw data directory, only used to build the path quoted in the
            hint message.
        logger: logging.Logger that receives the warning/info messages.

    Returns:
        List of mask names to reduce, in infoDict iteration order.

    """
    masksToReduce = []
    for key in infoDict.keys():
        if key in excludeMasks:
            continue
        addMask = False
        # Track the dict-valued entry explicitly instead of relying on the
        # loop variable surviving the loop: that variable is undefined for an
        # empty mask entry and otherwise depends on key order.
        maskEntry = None
        for value in infoDict[key].values():
            if not isinstance(value, dict):
                continue
            maskEntry = value
            if len(value['modelExists']) > 0:
                addMask = bool(np.all(value['modelExists']))
                break
        if addMask:
            masksToReduce.append(key)
        else:
            logger.warning("Skipping mask '%s' as no wavelength calibration reference model currently exists." % (key))
            if maskEntry is not None and 'ARC' in maskEntry:
                logger.info("    Use rss_mos_create_arc_model to make a reference model and then re-run (after deleting %s)." % (rawDir+os.path.sep+"imageInfo.pkl"))
                logger.info("    arc files list: %s", maskEntry['ARC'])
    return masksToReduce


def main():
    """Run the rss_mos_reducer command-line tool.

    Parses the command line, sets up logging to the terminal and to a
    time-stamped log file in the current directory, gathers the image
    information for rawDir and then either lists the masks found or reduces
    each selected mask in turn.

    Exits with status 1 if the requested maskName is not found.

    """
    args = makeArgParser().parse_args()

    rawDir = args.rawDir
    baseOutDir = args.reducedDir
    maskName = args.maskName

    threshold = args.threshold
    longslitThreshold = args.longslitThreshold
    subFrac = args.subFrac
    iterativeMethod = args.iterativeMethod
    excludeMasks = args.excludeMasks.split(",")
    extensionsList = args.extensionsList
    if extensionsList != "all":
        extensionsList = extensionsList.split(",")

    if not os.path.exists(baseOutDir):
        os.makedirs(baseOutDir)

    # Logging
    logger = logging.getLogger('RSSMOSPipeline')
    logger.setLevel(logging.DEBUG)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    # ... to terminal
    handler = logging.StreamHandler(sys.stdout)
    handler.setLevel(logging.DEBUG)
    handler.setFormatter(formatter)
    logger.addHandler(handler)
    # ... and to file
    logFileName = 'rss_mos_reducer-%s.log' % (datetime.datetime.now().isoformat())
    fh = logging.FileHandler(logFileName)
    fh.setLevel(logging.DEBUG)
    fh.setFormatter(formatter)
    logger.addHandler(fh)

    logger.info("started: %s" % (datetime.datetime.now().isoformat()))
    logger.info("parameters: %s" % str(args))

    # Sort out what's what...
    infoDict = RSSMOSTools.getImageInfo(rawDir)

    if maskName == 'list':
        maskList = sorted(infoDict.keys())
        print("Found %d masks:" % (len(maskList)))
        for key in maskList:
            infoStr = ""
            for value in infoDict[key].values():
                if isinstance(value, dict) and len(value['modelExists']) > 0:
                    infoStr = infoStr+str(value['modelExists'][0])
                    infoStr = infoStr+"\t"+str(value['modelFileNames'][0])
            print("    %s\t%s" % (key, infoStr))
        sys.exit()
    elif maskName != 'all':
        if maskName not in infoDict:
            print("ERROR: maskName not found. Try using 'list' to see available maskNames.")
            # Non-zero status so that calling scripts can detect the failure
            sys.exit(1)
        # Only the requested mask is kept
        infoDict = {maskName: infoDict[maskName]}

    masksToReduce = selectMasksToReduce(infoDict, excludeMasks, rawDir, logger)

    # We're organised by object name, reduce each in turn
    for maskName in masksToReduce:

        logger.info("Mask: %s" % (maskName))

        outDir = baseOutDir+os.path.sep+maskName
        if not os.path.exists(outDir):
            os.makedirs(outDir)

        testOutFiles = glob.glob(outDir+os.path.sep+'1DSpec_2DSpec_stackAndExtract'+os.path.sep+'*.fits')
        if len(testOutFiles) > 0 and args.skipDone:
            logger.info("skipping mask '%s' - found previously made output files" % (maskName))
            continue

        # The per-mask dictionary is stored under the mask's own maskID key;
        # copy the identifying fields into it so it can be passed around alone
        maskInfo = infoDict[maskName]
        maskDict = maskInfo[maskInfo['maskID']]
        maskDict['maskID'] = maskInfo['maskID']
        maskDict['objName'] = maskInfo['objName']
        maskType = maskInfo['maskType']

        RSSMOSTools.makeMasterFlats(maskDict, outDir)

        # Any other maskType falls through without a slit-cutting step
        if maskType == 'MOS':
            RSSMOSTools.cutIntoSlitLets(maskDict, outDir, threshold = threshold)
        elif maskType == 'LONGSLIT':
            RSSMOSTools.cutIntoPseudoSlitLets(maskDict, outDir, thresholdSigma = longslitThreshold)

        try:
            RSSMOSTools.applyFlatField(maskDict, outDir)
        except Exception:
            # Best effort: move on to the next mask, but keep the traceback in
            # the log and let KeyboardInterrupt/SystemExit propagate
            logger.warning("flat fielding failed on mask '%s' - try different threshold or longslitThreshold?" % (maskName), exc_info = True)
            continue

        RSSMOSTools.wavelengthCalibration2d(maskDict, outDir, extensionsList = extensionsList)

        RSSMOSTools.extractAndStackSpectra(maskDict, outDir, extensionsList = extensionsList,
                                           iterativeMethod = iterativeMethod, subFrac = subFrac)


if __name__ == '__main__':
    main()
 
