#!/usr/bin/env python

"""

nemoSpec: given a config file and a catalog of positions, extract the average flux in apertures at each
position (a matched filter can be used instead of apertures; see the --method option).

"""

import sys
import os
import datetime
from nemo import *
import nemo
from nemo import MockSurvey
import argparse
import astropy
import astropy.table as atpy
import astropy.io.fits as pyfits
from astLib import astWCS
import numpy as np
import pylab
import pickle
import types
import yaml
# Runs at module import time (this call sits outside the __main__ guard).
# NOTE(review): plotSettings is not imported by name in this file; presumably it is provided by
# `from nemo import *`, and update_rcParams presumably adjusts matplotlib rcParams - confirm.
plotSettings.update_rcParams()

#------------------------------------------------------------------------------------------------------------
def makeParser():
    """Build the command-line argument parser for nemoSpec.

    Two positional arguments are required (configFileName, catFileName). The optional switches select the
    output file name, the extraction method and its aperture radius, whether filtered maps are saved, an
    optional redshift catalog to cross match against, and the MPI-related behaviour.

    Returns:
        An argparse.ArgumentParser; call its parse_args() method to obtain the options.

    """

    parser=argparse.ArgumentParser("nemoSpec")
    parser.add_argument("configFileName", help="""A .yml configuration file.""")
    parser.add_argument("catFileName", help = "Catalog file name (.fits format).\
                        The catalog must contain at least the following columns: name, RADeg, decDeg.")
    parser.add_argument("-o", "--output", dest="outFileName", help = "Output catalog file name\
                        (.fits format). If not given, the name of the output catalog file will be based on\
                        catFileName.", default = None)
    parser.add_argument("-m", "--method", dest = "method", type = str, default = "CAP",
                        help = "Method used for extracting the spectrum. Options are: (i) 'CAP', which\
                        uses a compensated aperture photometry filter (use --radius-arcmin to set the\
                        aperture size); (ii) 'matchedFilter', which uses a simple single-frequency matched\
                        filter on each band, assuming a cluster signal template (this requires the input\
                        catalog to have been produced by nemo and contain template names in the format\
                        'Arnaud_M2e14_z0p4', for example, in order to choose an appropriate signal scale\
                        for each object in the catalog).")
    parser.add_argument("-r", "--radius-arcmin", dest = "diskRadiusArcmin", type = float, default = 4.0,
                        help = "CAP method only. Disk aperture radius in arcmin, within which the signal\
                        is measured. The background will be estimated in an annulus between\
                        diskRadiusArcmin < r < sqrt(2) * diskRadiusArcmin.")
    parser.add_argument("-w", "--write-maps", dest = "saveFilteredMaps", action = "store_true",
                        help = "matchedFilter method only: If set, saves the filtered maps in the\
                        nemoSpecCache directory (which is created in the current working directory, if it\
                        doesn't already exist).")
    parser.add_argument("-z", "--redshift-catalog", dest = "redshiftCatFileName", type = str, default = None,
                        help = "If given, this FITS table format catalog is cross matched (using 2.5\
                        arcmin radius) against the input catalog specified by catFileName. The analysis is\
                        then performed on the subset of objects with redshifts, and a redshift column is\
                        added to the output catalog. Note that redshift information is not used in the\
                        analysis; this feature is provided for convenience only.")
    # The example command must include both required positionals (config file and catalog),
    # otherwise copying it verbatim fails with an argparse error.
    parser.add_argument("-M", "--mpi", dest="MPIEnabled", action="store_true", help="Enable MPI. If you\
                        want to use this, run with e.g., mpiexec -np 4 nemoSpec configFile.yml catalog.fits -M",
                        default = False)
    parser.add_argument("-n", "--no-strict-errors", dest="noStrictMPIExceptions", action="store_true", 
                        help="Disable strict exception handling (applies under MPI only, i.e., must be\
                        used with the -M switch). If you use this option, you will get the full traceback\
                        when a Python Exception is triggered, but the code may not terminate. This is due\
                        to the Exception handling in mpi4py.", default = False)
    # -v prints the program name followed by the installed nemo version, then exits
    parser.add_argument("-v", "--version", action = 'version', version = '%(prog)s' + ' %s' % (nemo.__version__))

    return parser

#------------------------------------------------------------------------------------------------------------
if __name__ == '__main__':
    # Script entry point: parse the command line, read the input catalog (optionally cross matching it
    # against a redshift catalog), run pipelines.extractSpec, and write the resulting table to disk.
    # Under MPI, the per-rank tables are gathered and merged on rank 0, which is the only rank that writes.

    parser=makeParser()
    args=parser.parse_args()

    # The -n / --no-strict-errors switch turns strict MPI exception handling *off*
    strictMPIExceptions=not args.noStrictMPIExceptions

    parDictFileName=args.configFileName
    config=startUp.NemoConfig(parDictFileName, MPIEnabled = args.MPIEnabled, strictMPIExceptions = strictMPIExceptions)

    catFileName=args.catFileName
    outFileName=args.outFileName
    tab=atpy.Table.read(catFileName)

    if args.redshiftCatFileName is not None:
        print("... cross matching against redshift catalog using 2.5 arcmin radius ...")
        zTab=atpy.Table.read(args.redshiftCatFileName)
        if 'redshift' not in zTab.keys():
            # The file name lives on args - a bare redshiftCatFileName is undefined at this point
            raise Exception("No 'redshift' column found in '%s'" % (args.redshiftCatFileName))
        # Both tables are replaced by what crossMatch returns; the third return value (rDeg) is not used here
        tab, zTab, rDeg=catalogs.crossMatch(tab, zTab, radiusArcmin = 2.5)
        tab['redshift']=zTab['redshift']
        if 'redshiftErr' in zTab.keys():
            tab['redshiftErr']=zTab['redshiftErr']

    if outFileName is None:
        outFileName=catFileName.replace(".fits", "_spec.fits")
        if outFileName == catFileName:
            # No ".fits" in the input name, so the replace above was a no-op. Since the output is written
            # with overwrite = True, using the same name would destroy the input catalog.
            outFileName=catFileName+"_spec.fits"

    specTab=pipelines.extractSpec(config, tab, diskRadiusArcmin = args.diskRadiusArcmin, method = args.method,
                                  saveFilteredMaps = args.saveFilteredMaps)

    if config.MPIEnabled == True:
        config.comm.barrier()
        specTabList=config.comm.gather(specTab, root = 0)
        if config.rank == 0:
            print("... gathered catalogs ...")
            # We sometimes return [] if no objects found - we can't vstack those
            toStack=[collectedTab for collectedTab in specTabList if isinstance(collectedTab, atpy.Table)]
            # NOTE(review): if every rank returned [], toStack is empty here - confirm how vstack copes
            specTab=atpy.vstack(toStack)
            # Strip out duplicates (this is necessary when run in tileDir mode under MPI)
            if len(specTab) > 0:
                specTab, numDuplicatesFound, names=catalogs.removeDuplicates(specTab)
    if config.rank == 0:
        specTab=catalogs.flagTileBoundarySplits(specTab)
        specTab.sort('name')
        # Record the nemo version used in the output table metadata
        specTab.meta['NEMOVER']=nemo.__version__
        specTab.write(outFileName, overwrite = True)
