#!python

import argparse
from argparse import RawTextHelpFormatter
import sys
from os.path import abspath, expanduser
import os
import warnings
import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot as plt
sys.path.insert(0, os.path.abspath('..'))
from specdal.collection import Collection, proximal_join
import shutil

# Command-line interface of the pipeline. RawTextHelpFormatter is used so the
# explicit '\n' line breaks embedded in the help strings below are preserved.
parser = argparse.ArgumentParser(description='SpecDAL Pipeline',
                                 formatter_class=RawTextHelpFormatter)
# io options
parser.add_argument('input_dir', metavar='INPUT_PATH', action='store',
                    help='directory containing input files')
parser.add_argument('--proximal_reference', default=None, metavar='PATH',
                    action='store',
                    help='directory containing proximal reference spectral files')
parser.add_argument('-o', '--output_dir', metavar='PATH',
                    default='./specdal_output', action='store',
                    help='directory to store the csv files and figures')
parser.add_argument('-op', '--prefix', metavar='PREFIX',
                    type=str, action='store', default='dataset',
                    help='option to specify prefix for output dataset files')
parser.add_argument('-of', '--omit_figures', action='store_true',
                    help='option to omit output png figures')
parser.add_argument('-od', '--omit_data', action='store_true',
                    help='option to omit output csv files')
parser.add_argument('-oi', '--omit_individual', action='store_true',
                    help='option to omit output of individual csv file for each spectrum file')
# interpolation
parser.add_argument('-i', '--interpolate', default=None,
                    choices=['slinear', 'cubic'],
                    help='specify the interpolation method.\n'
                    'method descriptions can be found on scipy docs:\n'
                    'https://docs.scipy.org/doc/scipy-0.19.1/reference/generated/scipy.interpolate.interp1d.html')
parser.add_argument('-is', '--interpolate_spacing', metavar='SPC',
                    action="store", type=int, default=1,
                    help='specify desired spacing for interpolation in nanometers\n')
# overlap stitcher
parser.add_argument('-s', '--stitch', default=None,
                    choices=['mean', 'median', 'min', 'max'],
                    help='specify overlap stitching method;\n'
                    'not necessary if data at detector edges does not overlap')
# jump corrector
parser.add_argument('-j', '--jump_correct', default=None,
                    choices=['additive'],
                    help='specify jump correction method;')
parser.add_argument('-js', '--jump_correct_splices', metavar='WVL',
                    default=[1000, 1800], type=int, nargs='+',
                    help='wavelengths of jump locations')
# NOTE(review): the default reference is 0 while the help text numbers the
# detectors from 1 -- confirm which indexing jump_correct() expects.
parser.add_argument('-jr', '--jump_correct_reference', metavar='REF',
                    type=int, action='store', default=0,
                    help='specify the reference detector '
                    '(e.g. VNIR is 1, SWIR1 is 2)')
# groupby
parser.add_argument('-g', '--group_by', action='store_true',
                    help='create groups using filenames')
parser.add_argument('-gs', '--group_by_separator',  type=str,
                    metavar='S', default='_',
                    help='specify filename separator character to define groups')
parser.add_argument('-gi', '--group_by_indices', metavar='I', nargs='*', type=int,
                    help='specify the indices of the split filenames to define a group')
# The three aggregate flags share dest='aggr'; each one given on the command
# line appends its constant, so args.aggr lists the requested aggregates in
# command-line order. The default=[] on the first flag makes args.aggr an empty
# list when none of them is given.
parser.add_argument('-gmean', '--group_mean', dest='aggr', action='append_const',
                    default=[],
                    const='mean', help='calculate group means and append to group figures')
parser.add_argument('-gmedian', '--group_median', dest='aggr', action='append_const',
                    const='median', help='calculate group median and append to group figures')
parser.add_argument('-gstd', '--group_std', dest='aggr', action='append_const',
                    const='std', help='calculate group standard deviation and append to group figures')
# misc
parser.add_argument('-q', '--quiet', default=False, action='store_true',
                    help='suppress progress messages')
parser.add_argument('-f', '--force', default=False, action='store_true',
                    help='if output path exists, remove previous output and run')

# Parse the command line once; every later stage of the script reads its
# settings from ``args``.
args = parser.parse_args()

################################################################################
# main
################################################################################
# Progress messages are printed unless -q/--quiet was given.
VERBOSE = not args.quiet

def print_if_verbose(*args, **kwargs):
    """Forward all arguments to ``print`` when the VERBOSE global is truthy.

    Nothing is printed when VERBOSE is falsy. Always returns None.
    """
    # NOTE: inside this function ``args`` is the tuple of positional
    # arguments, not the module-level namespace of the same name.
    if not VERBOSE:
        return
    print(*args, **kwargs)
    
# Resolve the input/output locations to absolute paths ('~' expanded) so the
# checks and messages below are unambiguous.
indir = abspath(expanduser(args.input_dir))
outdir = abspath(expanduser(args.output_dir))
datadir = os.path.join(outdir, 'data')
figdir = os.path.join(outdir, 'figures')

if not os.path.exists(indir):
    raise FileNotFoundError("path " + indir + " does not exist")


def _is_inside(path, parent):
    """Return True when *path* is *parent* itself or is located beneath it.

    Both paths are resolved (symlinks followed) and case-normalised first so
    that equivalent spellings of the same location compare equal.
    """
    path = os.path.normcase(os.path.realpath(path))
    parent = os.path.normcase(os.path.realpath(parent))
    return path == parent or path.startswith(parent.rstrip(os.sep) + os.sep)


if os.path.exists(outdir):
    # The existing output directory is deleted below. Refuse to continue when
    # that would also delete the measurements this run is supposed to read.
    _sources = [indir]
    if args.proximal_reference:
        _sources.append(abspath(expanduser(args.proximal_reference)))
    for _src in _sources:
        if _is_inside(_src, outdir):
            raise ValueError("output path " + outdir + " contains input path "
                             + _src + "; refusing to remove it")

    # Without --force, keep asking until the user answers 'y' or 'n'.
    while not args.force:
        # prompt user for action
        ans = input(outdir + ' already exists. Are you sure you want to remove its contents? [y/n]: ')
        ans = ans.strip().lower()
        if ans == 'y':
            args.force = True
        elif ans == 'n':
            print('exiting pipeline...')
            sys.exit(0)
    print('removing {}'.format(outdir))
    shutil.rmtree(outdir)

# make output directories
for d in (outdir, datadir, figdir):
    os.makedirs(d, exist_ok=True)

# Read the target measurements.
c = Collection(name=args.prefix)
print_if_verbose('Reading target measurements from ' + indir)
c.read(directory=indir)

# Optionally read the proximal reference measurements. The path is normalised
# and checked the same way as the input directory so that '~' and relative
# paths behave consistently for both options.
c_base = None
if args.proximal_reference:
    refdir = abspath(expanduser(args.proximal_reference))
    if not os.path.exists(refdir):
        raise FileNotFoundError("path " + refdir + " does not exist")
    print_if_verbose('Reading base measurements from ' + refdir)
    c_base = Collection(name=args.prefix + '_base')
    c_base.read(directory=refdir)

# Every preprocessing step is applied to the target collection first and then,
# when present, to the reference collection with identical settings.
collections = [c] if c_base is None else [c, c_base]

if args.interpolate:
    print_if_verbose('interpolating...')
    for coll in collections:
        coll.interpolate(spacing=args.interpolate_spacing, method=args.interpolate)
if args.stitch:
    print_if_verbose('Stitching...')
    for coll in collections:
        coll.stitch(method=args.stitch)
if args.jump_correct:
    print_if_verbose('Jump correcting...')
    for coll in collections:
        coll.jump_correct(splices=args.jump_correct_splices,
                          reference=args.jump_correct_reference,
                          method=args.jump_correct)

# From here on ``c`` is the result of joining the reference with the target.
if c_base is not None:
    print_if_verbose('Joining proximal data...')
    c = proximal_join(c_base, c, on='gps_time_tgt', direction='nearest')

# group by: ``groups`` stays None unless -g/--group_by was given
if args.group_by:
    print_if_verbose('Grouping...')
    groups = c.groupby(
        separator=args.group_by_separator,
        indices=args.group_by_indices,
    )
else:
    groups = None

# output individual spectra: one csv and one png per spectrum, written to the
# 'indiv' subdirectories of the data and figure directories
if not args.omit_individual:
    print_if_verbose('Saving individual spectrum outputs...')
    indiv_datadir = os.path.join(datadir, 'indiv')
    indiv_figdir = os.path.join(figdir, 'indiv')
    for new_dir in (indiv_datadir, indiv_figdir):
        os.mkdir(new_dir)
    for spectrum in c.spectra:
        csv_path = os.path.join(indiv_datadir, spectrum.name + '.csv')
        png_path = os.path.join(indiv_figdir, spectrum.name + '.png')
        spectrum.to_csv(csv_path)
        spectrum.plot(legend=False)
        plt.savefig(png_path, bbox_inches='tight')
        # close the figure before moving on to the next spectrum
        plt.close()

# output whole and group data: the full collection first, then one csv per
# group (if grouping was requested), each named after the collection/group
if not args.omit_data:
    print_if_verbose('Saving entire and grouped data outputs...')
    csv_jobs = [(c.name, c)]
    if groups:
        csv_jobs.extend(groups.items())
    for stem, coll in csv_jobs:
        coll.to_csv(os.path.join(datadir, stem + '.csv'))

# calculate group aggregates
if args.aggr and not groups:
    # Aggregates are computed per group. Without -g/--group_by ``groups`` is
    # None, and iterating it below would abort the run with AttributeError
    # after all other processing has already been done; warn and skip instead.
    warnings.warn('group aggregates ({}) were requested but no groups are '
                  'available; use -g/--group_by to define groups. '
                  'Skipping aggregates.'.format(', '.join(args.aggr)))
elif args.aggr:
    print_if_verbose('Calculating group aggregates...')
    for aggr in args.aggr:
        # Call the aggregate method named by ``aggr`` ('mean', 'median' or
        # 'std') on every group with append=True, and gather the returned
        # values into a new collection.
        aggr_coll = Collection(name=c.name + '_' + aggr,
                               spectra=[getattr(group_coll, aggr)(append=True)
                                        for group_coll in groups.values()],
                               measure_type=c.measure_type)
        # output
        print_if_verbose('Saving group {} outputs...'.format(aggr))
        aggr_coll.to_csv(os.path.join(datadir, aggr_coll.name + '.csv'))
        aggr_coll.plot(legend=False)
        plt.savefig(os.path.join(figdir, aggr_coll.name + '.png'), bbox_inches='tight')
        plt.close()

# output whole and group figures (possibly with aggregates appended): the full
# collection first, then one png per group if grouping was requested
if not args.omit_figures:
    print_if_verbose('Saving entire and grouped figure outputs...')
    figure_jobs = [(c.name, c)]
    if groups:
        figure_jobs.extend(groups.items())
    for stem, coll in figure_jobs:
        coll.plot(legend=False)
        plt.savefig(os.path.join(figdir, stem + ".png"), bbox_inches="tight")
        # close each figure once it has been saved
        plt.close()
