#!/usr/bin/env python

""" This is a first prototype of the MAUD script to filter a NetCDF file

    It is far from an ideal solution. It is just meant to solve Bia's
      problem for now.

    Run this command with -h to see the help. Here is an example:

    maud4nc --lowpasswindowlength=7 --highpasswindowlength=90 --var='temperature' -w hamming -o model_output.nc model_input.nc
"""

from os.path import isfile
from optparse import OptionParser
#import multiprocessing as mp

import numpy as np
from numpy import ma
from netCDF4 import Dataset

from maud import wmean_bandpass_1D
from maud import wmean_1D
#from cmaud import wmean_1D


# ==== Parsing the options on command line
# The single positional argument is the NetCDF file to be filtered.
parser = OptionParser(usage="usage: %prog [options] inputfile")

# The window lengths are converted to float by optparse itself, so a
#   malformed value ends in a usage error instead of a ValueError traceback.
#   When an option is not given its value stays None.
parser.add_option("--lowpasswindowlength", dest="lowpasswindowlength",
    type="float",
    help="Length of the filter window. Must be on the same scale of the scalevar")

parser.add_option("--highpasswindowlength", dest="highpasswindowlength",
    type="float",
    help="Length of the filter window. Must be on the same scale of the scalevar")

parser.add_option("-o", dest="outputfile",
    help="The output file where the filtered data will be saved.")

parser.add_option("--scalevar", dest="scalevar",
    help="The scale on the dimension to be filtered.", default="time")

parser.add_option("--var", dest="var",
    help="Variable to be filtered")

parser.add_option("-w", dest="windowmethod",
    help="Type of window [hamming, hann, boxcar, triangle, lanczos]",
    default="hamming")


(options, args) = parser.parse_args()

# The input file and the variable name are mandatory. Fail here with a
#   usage message (exit status 2) instead of an IndexError on args[0] or a
#   TypeError on options.var + "_maud" further down.
if len(args) < 1:
    parser.error("Missing the input NetCDF file.")
if options.var is None:
    parser.error("Missing the variable to be filtered (--var).")

varname = options.var
if options.outputfile is None:
    # No output file: the result is stored in the input file itself, as a
    #   new variable named after the original one.
    options.varout = options.var + "_maud"
    newfile = False
else:
    # Output file given: the result keeps the original variable name.
    options.varout = options.var
    newfile = True


if not isfile(args[0]):
    parser.error("Seems like %s is not a file." % args[0])


#if var not in nc.variables.keys():
#    import sys; sys.exit()

# ====
if not newfile:
    # In-place mode: open the input file for appending and create the
    #   variable that will receive the filtered data, with the same type
    #   and dimensions of the original one.
    nc = Dataset(args[0], 'a')

    source = nc.variables[options.var]
    attributes = source.ncattrs()

    # Reuse the fill value of the original variable when it has one.
    #   fill_value=None is the default of createVariable, so a single call
    #   covers both cases, and any real failure while creating the variable
    #   is no longer hidden by a bare except.
    fill_value = getattr(source, '_FillValue', None)
    if (fill_value is not None) and ('_FillValue' in attributes):
        attributes.remove('_FillValue')
    # NOTE(review): attributes is collected here but it is not copied to
    #   the new variable anywhere in this script.

    out = nc.createVariable(options.varout,
              source.dtype,
              source.dimensions,
              fill_value=fill_value)

if newfile:
    # New file mode: copy the dimensions and their coordinate variables of
    #   the target variable into the output file, filter the variable one
    #   slice of its second dimension at a time, and then exit the script.
    import sys

    # Nothing can be filtered without at least one window length. Check it
    #   before the output file is created, instead of quitting silently in
    #   the middle of the loop with both files still open.
    if (options.lowpasswindowlength is None) and \
            (options.highpasswindowlength is None):
        sys.exit("At least one of --lowpasswindowlength or "
                 "--highpasswindowlength is required.")

    ncin = Dataset(args[0], 'r')
    ncout = Dataset(options.outputfile, 'w')

    try:
        data = ncin.variables[varname]

        # Global Attributes

        # Copying dimensions
        dims = data.dimensions
        for dim in dims:
            ncout.createDimension(dim, len(ncin.dimensions[dim]))

        # Copying variables related to the dimensions
        variables = {}

        for dim in dims:
            # A dimension is not required to have a coordinate variable.
            if dim not in ncin.variables:
                continue
            variables[dim] = ncout.createVariable(dim,
                    ncin.variables[dim].datatype.name,
                    ncin.variables[dim].dimensions)
            variables[dim][:] = ncin.variables[dim][:]
            for a in ncin.variables[dim].ncattrs():
                setattr(ncout.variables[dim], a,
                        getattr(ncin.variables[dim], a))

        # Not every variable defines a _FillValue. None is the default of
        #   createVariable.
        output = ncout.createVariable(varname,
                data.dtype.name,
                data.dimensions,
                fill_value=getattr(data, '_FillValue', None))

        # The attributes are copied only after the data is written (see the
        #   end of this block), since the values are packed by hand below.
        attrs = [a for a in data.ncattrs() if a != '_FillValue']

        axis = data.dimensions.index(
                ncin.variables[options.scalevar].dimensions[0])
        if axis != 0:
            raise ValueError(
                    "The dimension of %s must be the first dimension of %s"
                    % (options.scalevar, varname))

        N = data.shape[1]

        # The progress bar is optional. Any problem to set it up only
        #   disables it.
        try:
            from progressbar import ProgressBar
            pbar = ProgressBar(maxval=N).start()
        except Exception:
            pbar = None

        # np.float was only an alias to the builtin float, and it does not
        #   exist anymore in the recent versions of NumPy.
        scale = ncin.variables[options.scalevar][:].astype(float)

        #npes = 2*mp.cpu_count()
        #qout = mp.Queue(2*npes)
        #pool = mp.Pool(npes)

        add_offset = getattr(data, 'add_offset', None)
        if add_offset is not None:
            add_offset = float(add_offset)

        scale_factor = getattr(data, 'scale_factor', None)
        if scale_factor is not None:
            scale_factor = float(scale_factor)

        # Integer storage needs rounding, otherwise the values would be
        #   truncated on writing.
        round_output = np.issubdtype(output.dtype, np.integer)

        for n in range(N):
            print("n: %s/%s " % (n, N))
            if pbar is not None:
                pbar.update(n + 1)

            series = data[:, n]

            if (options.highpasswindowlength is not None) and \
                    (options.lowpasswindowlength is not None):
                # Band pass
                tmp = wmean_bandpass_1D(series,
                        lshorterpass = options.highpasswindowlength,
                        llongerpass = options.lowpasswindowlength,
                        t = scale,
                        method = options.windowmethod,
                        axis=0)

            elif (options.lowpasswindowlength is not None):
                # Low pass
                tmp = wmean_1D(series,
                        l=options.lowpasswindowlength,
                        t=scale,
                        method=options.windowmethod,
                        axis=0)

            else:
                # High pass: the data minus its low pass filtered version
                tmp = series - \
                        wmean_1D(series,
                                l=options.highpasswindowlength,
                                t=scale,
                                method=options.windowmethod,
                                axis=0)

            # Pack the result by hand, since the output variable does not
            #   have add_offset/scale_factor at this point.
            if add_offset is not None:
                tmp = tmp - add_offset
            if scale_factor is not None:
                tmp = tmp / scale_factor

            if round_output:
                output[:, n] = tmp.round()
            else:
                output[:, n] = tmp

        for a in attrs:
            setattr(output, a, getattr(data, a))

        ncout.sync()
    finally:
        # Close both files even if something fails on the way.
        ncin.close()
        ncout.close()

    sys.exit()


# In-place mode: band pass each (i, j) series of the 3D variable along its
#   first dimension, and save it in the variable created above.
# NOTE(review): this branch always calls the band pass filter, so it looks
#   like both window lengths must be given here -- confirm against maud.
try:
    data = nc.variables[options.var]

    axis = data.dimensions.index(nc.variables[options.scalevar].dimensions[0])

    if axis != 0:
        raise ValueError(
                "The dimension of %s must be the first dimension of %s"
                % (options.scalevar, options.var))
    if len(data.dimensions) != 3:
        raise ValueError("%s must have exactly 3 dimensions" % options.var)

    I, J = data.shape[1:]

    # The progress bar is optional. Any problem to set it up only disables
    #   it.
    try:
        from progressbar import ProgressBar
        pbar = ProgressBar(maxval=I*J).start()
    except Exception:
        pbar = None

    # The scale is the same for every series, so read it only once.
    scale = nc.variables[options.scalevar][:]

    for i in range(I):
        for j in range(J):
            if pbar is not None:
                # Number of series processed so far, from 1 up to I*J.
                pbar.update(i*J + j + 1)

            # wmean_bandpass_1D is the band pass function imported from
            #   maud. The name wmean_bandpass_1D_serial, used here before,
            #   is not defined anywhere in this script.
            out[:, i, j] = wmean_bandpass_1D(data[:, i, j],
                    lshorterpass = options.highpasswindowlength,
                    llongerpass = options.lowpasswindowlength,
                    t = scale,
                    method = options.windowmethod,
                    axis=0)
finally:
    # Close the file even if something fails on the way.
    nc.close()
