#!python -Wignore
"""StarbugII Matching 
usage: starbug2-match [-BDGfhv] [-o output] [-p file.param] [-s KEY=VAL] table.fits ...
    -B  --band               : match in "BAND" mode (does not preserve a column for every frame)
    -D  --dither             : match in "DITHER" mode (preserves a column for every frame)
    -f  --full               : export full catalogue
    -G  --generic            : match in "GENERIC" mode
    -h  --help               : show help message
    -v  --verbose            : verbose mode (e.g. -vh shows this full help text)
    -o  --output  file.fits  : output matched catalogue
    -p  --param   file.param : load starbug parameter file
    -s  --set     option     : set value in parameter file at runtime (-s MATCH_THRESH=1)
"""

import os,sys,getopt,glob
import numpy as np
import pkg_resources
from astropy.io import fits
from astropy.table import Table, hstack, vstack
import starbug2
from starbug2 import utils
from starbug2 import matching

# Bit flags OR-ed into ``options`` while parsing the command line
VERBOSE=0x01    # -v : verbose mode (usage() prints the whole help text)
KILLPROC=0x02   # not referenced elsewhere in this script

# Matching mode flags
BANDMATCH   =0x04   # -B
DITHERMATCH =0x08   # -D
GENERICMATCH=0x10   # -G

EXPFULL = 0x100     # -f : also export the full catalogue

options=0           # accumulated bit flags from the command line

parameters={}       # replaced by the loaded parameter file further down
pfile=None          # -p : parameter file path
output=None         # -o : output file name
setopt={}           # -s : KEY=VAL runtime parameter overrides

def usage():
    """Print usage information and exit.

    When the VERBOSE flag is set in ``options`` the whole module docstring is
    shown; otherwise only its second line (the "usage:" synopsis).
    ``sys.exit`` with a string prints it to stderr and exits with status 1.
    ``site.quit`` is avoided because it is meant for the interactive
    interpreter and is absent under ``python -S``.
    """
    if options & VERBOSE:
        sys.exit(__doc__)
    sys.exit(__doc__.split("\n")[1])

# Parse the command line. An unknown option or a missing option argument is
# reported and the usage line is shown instead of a traceback.
try:
    opts,args=getopt.getopt(sys.argv[1:], "BDfGhvo:p:s:", ("band","dither","full", "generic", "help","verbose",
                                                    "output=", "param=", "set="))
except getopt.GetoptError as err:
    utils.perror("%s\n"%err)
    usage()

# -h is acted on after the loop so that "-h -v" and "-v -h" behave the same
show_help=False
for opt,optarg in opts:
    if opt in ("-h", "--help"): show_help=True
    elif opt in ("-v", "--verbose"): options|=VERBOSE
    elif opt in ("-o", "--output"): output=optarg
    elif opt in ("-p", "--param"): pfile=optarg
    elif opt in ("-f","--full"): options|=EXPFULL

    elif opt in ("-s","--set"):
        # Split on the first "=" only, so the value may itself contain "="
        key,sep,val=optarg.partition('=')
        if not sep:
            utils.perror("unable to set parameter, use syntax -s KEY=VALUE\n")
            continue
        # Numeric values are stored as float; anything else stays a string
        try: val=float(val)
        except ValueError: pass
        setopt[key]=val

    elif opt in ("-B","--band"): options|=BANDMATCH
    elif opt in ("-D","--dither"): options|=DITHERMATCH
    elif opt in ("-G","--generic"): options|=GENERICMATCH

if show_help: usage()

# At least one input catalogue is required
if not args: usage()

# Parameter file precedence: -p FILE, then ./starbug.param, then the packaged default
if pfile: parameters=utils.load_params(pfile)
elif os.path.exists("./starbug.param"): parameters=utils.load_params("./starbug.param")
else: parameters=utils.load_params("%s/default.param"%pkg_resources.resource_filename("starbug2", "param/"))

if not parameters:
    utils.perror("failed to load parameter file\n")
    sys.exit("..quitting :(")
parameters.update(setopt)   # "-s KEY=VAL" overrides take precedence over the file

tables=[Table.read(fname,format="fits") for fname in args]

# Columns carried through dither/cascade matching: the package defaults plus
# MATCH_COLS. Copy the default list so starbug2.match_cols is not extended in place.
colnames=list(starbug2.match_cols)
colnames+=[name for name in parameters["MATCH_COLS"].split() if name not in colnames]
dthreshold=parameters["MATCH_THRESH"]   # matching separation threshold (presumably arcsec)
nthreshold=parameters["NEXP_THRESH"]    # -1 disables the NUM cut applied later
rmmatch=parameters["RM_MATCH"]          # deprecated; superseded by NEXP_THRESH

if options & BANDMATCH:
    # BAND mode: output keeps RA, DEC and, for each filter, a FILTER column
    # plus its "e"-prefixed counterpart
    known_filters=list(starbug2.filters.keys())  # canonical filter order, used for ranking
    filters=[]
    tomatch={ starbug2.NIRCAM:[], starbug2.MIRI:[] }
    _colnames=["RA","DEC"]
    for i,tab in enumerate(tables):
        # Prefer the FILTER header keyword; otherwise fall back to a single
        # column named after a known filter
        fltr=tab.meta.get("FILTER")
        if not (fltr and fltr in known_filters):
            candidates=set(tab.colnames) & set(known_filters)
            if len(candidates)!=1:
                # none found, or ambiguous (several filter columns)
                utils.perror("Unable to determine FILTER for \"%s\"\n"%args[i])
                continue
            fltr=candidates.pop()
        filters.append(fltr)
        tomatch[starbug2.filters[fltr].instr].append(tab)
        _colnames+=[fltr,"e%s"%fltr]

    if tomatch[starbug2.NIRCAM] and tomatch[starbug2.MIRI]:
        utils.printf("Detected NIRCam to MIRI matching\n")
        nircam_matched=matching.band_match(tomatch[starbug2.NIRCAM], colnames=_colnames)
        utils.printf("\n")
        miri_matched=matching.band_match(tomatch[starbug2.MIRI], colnames=_colnames)
        utils.printf("\n")

        # Column deciding which NIRCam sources enter the NIRCam-MIRI match:
        # LOCKCOL if set, else the NIRCam filter latest in starbug2.filters order
        if not (lockcol:=parameters.get("LOCKCOL")):
            nircam_filters=[f for f in filters if starbug2.filters[f].instr==starbug2.NIRCAM]
            lockcol=max(nircam_filters, key=known_filters.index)
        # NIRCam sources with NaN in lockcol skip the cross-match and are appended as-is
        mask= np.isnan(nircam_matched[lockcol])
        load=utils.loading(len(miri_matched), msg="Combining NIRCAM-MIRI(%.2g\")"%dthreshold)
        matched,_=matching.generic_match((nircam_matched[~mask],miri_matched), threshold=dthreshold, add_src=True, load=load)
        matched.remove_column("NUM")
        matched=vstack((matched, nircam_matched[mask]))
    else:
        matched=matching.band_match(tables, colnames=_colnames)

    utils.export_table(matched,fname=output or "out.fits")

else:
    # Exactly one strategy runs: GENERIC takes precedence over DITHER,
    # CASCADE is the default
    if options & GENERICMATCH:
        options|=EXPFULL  # generic mode always exports the full catalogue
        av,full=matching.generic_match(tables,threshold=dthreshold, add_src=True)
    elif options & DITHERMATCH:
        av,full=matching.dither_match(tables, threshold=dthreshold, colnames=colnames)
    else:
        av,full=matching.cascade_match(tables, threshold=dthreshold, colnames=colnames)

    # Force consistent column dtypes, then fill empty (masked) values with NaN
    special_dtypes={"Catalogue_Number":str, "flag":np.uint16}
    dtypes=[special_dtypes.get(name,float) for name in full.colnames]
    full=Table(full,dtype=dtypes).filled(np.nan)

    if av:
        av.meta.update(tables[0].meta)
        if nthreshold!=-1:
            # keep rows whose NUM (presumably the number of input catalogues
            # a source was found in) reaches NEXP_THRESH
            av=av[av["NUM"]>=nthreshold]
        if rmmatch!=-1:
            utils.perror("warning: RM_MATCH has been removed, use NEXP_THRESH instead\n")

    if output is None:
        output=utils.combine_fnames(list(args), ntrys=100)
    dname,fname,_=utils.split_fname(output)

    utils.printf("-> %s/%s*\n"%(dname,fname))
    if options&EXPFULL: utils.export_table(full,fname="%s/%sfull.fits"%(dname,fname))
    if av: utils.export_table(av,"%s/%smatch.fits"%(dname,fname))
