#!/usr/bin/env python

"""

nemo driver script: for filtering maps and finding clusters

"""

import sys
#print("Running under python: %s" % (sys.version))
import os
import datetime, time
from nemo import *
import nemo
from nemo import MockSurvey
import argparse
import astropy
import astropy.table as atpy
import astropy.io.fits as pyfits
from astLib import astWCS
import numpy as np
import pylab
import pickle
import types
import yaml
#import IPython
pylab.matplotlib.interactive(False)

#------------------------------------------------------------------------------------------------------------
def makeParser():
    """Build the command-line argument parser for the nemo driver script.

    The parser takes one positional argument (configFileName), the on/off switches -S, -I, -M, -T
    and -n (each False unless given), the -f option (a catalog file name; None unless given), and
    -v, which reports the program name followed by nemo.__version__.

    Returns:
        The configured argparse.ArgumentParser object.

    """
    # Assembled up front so that the -v add_argument call stays short
    versionLabel='%(prog)s' + ' %s' % (nemo.__version__)

    parser=argparse.ArgumentParser(prog = "nemo")
    parser.add_argument("configFileName", help="A .yml configuration file.")

    # No explicit defaults are passed below: argparse already uses False for store_true switches
    # and None for options that take a value
    parser.add_argument("-S", "--calc-selection-function", action="store_true", dest="calcSelFn",
                        help="Calculate the completeness in terms of cluster mass, assuming the scaling\
                        relation parameters given in the .yml config file. Output will be written under the\
                        nemoOutput/selFn directory. This switch overrides the calcSelFn parameter in the\
                        .yml config file.")
    parser.add_argument("-I", "--run-source-injection-test", action="store_true",
                        dest="sourceInjectionTest",
                        help="Run a source injection test, using the settings given\
                        in the .yml config file. Output will be written under the nemoOutput/diagnostics\
                        (raw data) and nemoOutput/selFn directories (position recovery model fits).\
                        This switch overrides the sourceInjectionTest parameter in the .yml config\
                        file.")
    parser.add_argument("-f", "--forced-photometry-catalog", dest="forcedCatalogFileName",
                        help="If given, instead of detecting objects, perform forced photometry in the\
                        filtered maps at object locations given in the catalog. The catalog must contain at\
                        least the following columns: name, RADeg, decDeg.")
    parser.add_argument("-M", "--mpi", action="store_true", dest="MPIEnabled",
                        help="Enable MPI. If you \
                        want to use this, run with e.g., mpiexec -np 4 nemo configFile.yml -M")
    parser.add_argument("-T", "--tiling-check", action="store_true", dest="tilingCheck",
                        help=" Runs until the tiling stage, then exits, providing info on the number of\
                        tiles used, and writing tile coordinates to the selFn/ directory.")
    parser.add_argument("-n", "--no-strict-errors", action="store_true", dest="noStrictMPIExceptions",
                        help="Disable strict exception handling (applies under MPI only, i.e., must be\
                        used with the -M switch). If you use this option, you will get the full traceback\
                        when a Python Exception is triggered, but the code may not terminate. This is due\
                        to the Exception handling in mpi4py.")
    parser.add_argument("-v", "--version", action = 'version', version = versionLabel)

    return parser
    
#------------------------------------------------------------------------------------------------------------
if __name__ == '__main__':
    # Driver sequence: parse the command line, build the run configuration, make the object catalog
    # (detection or forced photometry), optionally fit Q, make the RMS tables, optionally run the
    # source injection test, then tidy up and (optionally) run the selection function calculations.

    parser=makeParser()
    args=parser.parse_args()

    # The -n switch turns strict MPI exception handling off; it is on otherwise
    strictMPIExceptions=not args.noStrictMPIExceptions

    parDictFileName=args.configFileName
    config=startUp.NemoConfig(parDictFileName, calcSelFn = args.calcSelFn,
                              sourceInjectionTest = args.sourceInjectionTest, MPIEnabled = args.MPIEnabled,
                              strictMPIExceptions = strictMPIExceptions, writeTileInfo = True)
    if args.tilingCheck:
        print(">>> Tiling check:")
        print("... This config has %d tiles." % (len(config.allTileNames)))
        sys.exit()

    # Output catalog name depends on whether we detect objects or do forced photometry
    config.parDict['forcedPhotometryCatalog']=args.forcedCatalogFileName
    if config.parDict['forcedPhotometryCatalog'] is not None:
        label=os.path.splitext(config.parDict['forcedPhotometryCatalog'])[0]
        label=os.path.basename(label)+"_"+os.path.basename(config.rootOutDir)+"_forcedCatalog"
        # NOTE: this name has no directory part, so the forced catalog lands in the current directory
        optimalCatalogFileName=label+".csv"
    else:
        optimalCatalogFileName=config.rootOutDir+os.path.sep+"%s_optimalCatalog.csv" % (os.path.split(config.rootOutDir)[-1])

    # Companion outputs share the catalog's stem. Only the final extension is swapped: a plain
    # str.replace(".csv", ...) would also rewrite any ".csv" found in directory or catalog names
    catalogStem=os.path.splitext(optimalCatalogFileName)[0]

    if not os.path.exists(optimalCatalogFileName):
        optimalCatalog=pipelines.filterMapsAndMakeCatalogs(config, writeAreaMask = True, writeFlagMask = True)
        # Only rank 0 post-processes and writes the catalog
        if config.rank == 0:
            optimalCatalog=catalogs.flagTileBoundarySplits(optimalCatalog)
            optimalCatalog.sort('name')
            catalogs.writeCatalog(optimalCatalog, optimalCatalogFileName)
            catalogs.writeCatalog(optimalCatalog, catalogStem+".fits")
            addInfo=[{'key': 'SNR', 'fmt': '%.1f'}]
            catalogs.catalog2DS9(optimalCatalog, catalogStem+".reg",
                                 addInfo = addInfo, color = "cyan")
    else:
        if config.rank == 0:
            print("... already made catalog %s" % (optimalCatalogFileName))

    # Q function (filter mismatch) - if needed options have been given
    # ('== True' is kept for values read from the config so that the accepted values are unchanged)
    if config.parDict.get('photFilter') is not None and config.parDict['fitQ'] == True:
        if not os.path.exists(config.selFnDir+os.path.sep+"QFit.fits"):
            signals.fitQ(config)

    # Generates the noise versus area tables, and adds 'footprint_label' columns to output catalogs
    # NOTE: an MPI barrier before this step was considered, but is hopefully not needed
    pipelines.makeRMSTables(config)

    # Source injection test - one way to calculate completeness ; needed for position recovery tests
    sourceInjTable=None
    sourceInjPath=config.selFnDir+os.path.sep+"sourceInjectionData.fits"
    if not os.path.exists(sourceInjPath):
        if config.parDict.get('sourceInjectionTest') == True:
            if config.MPIEnabled == True:
                config.comm.barrier()   # Otherwise, some processes can begin before catalog written to disk and then crash
            sourceInjTable=maps.sourceInjectionTest(config)
    else:
        if config.rank == 0:
            print("... already made source injection data %s" % (sourceInjPath))

    # Tidying up etc. - done by rank 0 only
    if config.rank == 0:
        print("... tidying up [time taken to here = %.3f sec]" % (time.time()-config._timeStarted))

        # Save the source injection data, and plot tile-averaged position recovery test
        if sourceInjTable is not None:
            sourceInjTable.meta['NEMOVER']=nemo.__version__
            # Same path that was tested for existence above
            sourceInjTable.write(sourceInjPath, overwrite = True)
            maps.positionRecoveryAnalysis(sourceInjTable, config.diagnosticsDir+os.path.sep+"positionRecovery.pdf",
                                          percentiles = [50, 95, 99.7], plotRawData = True,
                                          pickleFileName = config.diagnosticsDir+os.path.sep+'positionRecovery.pkl',
                                          selFnDir = config.selFnDir)

        # Cache file containing weights for relativistic corrections
        # Saves doing this later (e.g., when nemoMass or nemoSelFn run) and it's quick to do
        signals.getFRelWeights(config)

        # Tidy up by making MEF files and deleting the (potentially 100s) of per-tile files made
        completeness.tidyUp(config)

        # Now do all completeness calculations, output etc. using a SelFn object
        if config.parDict.get('calcSelFn') == True:
            completeness.completenessByFootprint(config)
            if 'massLimitMaps' in config.parDict['selFnOptions']:
                completeness.makeMassLimitMapsAndPlots(config)

        print(">>> Finished [time taken = %.3f sec]" % (time.time()-config._timeStarted))

