#!/usr/bin/env python

""" This is a MAUD script to filter a NetCDF file

    Far away from an ideal solution. It's just to resolve the Bia's
      problem for now.

    This script will use the latitude and longitude provided in the
      netcdf file itself, and apply a moving average window considering
      the linear distances.

    This script is mostly useful for regular cartesian grids, which are
      therefore not regular in linear distance, or for numerical model
      grids, whose spacing between grid points can change.

      Type this command with -h to see the help. Here is an example:

      maud4latlonnc -l 600e3 --var='temperature' -w hamming --interp --npes=18 model_output.nc

      This will save on the same netcdf another variable named temperature_maud
"""

from optparse import OptionParser

import numpy as np
from numpy import ma
from netCDF4 import Dataset

from maud import window_1Dbandpass
from cmaud import window_mean_2D_latlon


# ==== Parsing the options on command line
# The NetCDF file to be filtered is the first positional argument; it is
# opened further down in append mode, so the result goes into that same file.
parser = OptionParser(usage="usage: %prog [options] file.nc")

parser.add_option("-l",
    action="store", type="float", dest="l",
    help="The filter scale in meters. For example: -l 100e3")

parser.add_option("--var", dest="var",
    help="Variable to be filtered")

parser.add_option("-w", dest="windowmethod",
    help="Type of window [hamming, hann, boxcar, triangle, lanczos]",
    default="hamming")

parser.add_option("--interp",
    action="store_true", dest="interp",
    help="If selected fill the masked points if is there available data around.",
    default = False)

parser.add_option("--npes",
    action="store", type="int", dest="npes",
    help="Defines the number of parallel processes.",
    default=None)

(options, args) = parser.parse_args()

# optparse has no concept of a required option, so the mandatory inputs are
# checked by hand. parser.error() prints the usage plus the message and exits
# with status 2, instead of failing later with a bare IndexError/KeyError.
if not args:
    parser.error("Missing the NetCDF file to be filtered.")
if options.var is None:
    parser.error("Missing the variable to be filtered, use --var.")
if options.l is None:
    parser.error("Missing the filter scale in meters, use -l.")

# Echo what was understood from the command line.
print("args:  %s" % (args,))
print("options:  %s" % (options,))
# ============================================================================
# Open in append mode: the output variable is written into the input file.
nc = Dataset(args[0], 'a')

# Stop early, with a readable message, if the requested variable is not in
# the file, instead of failing below with a bare KeyError.
if options.var not in nc.variables:
    import sys
    nc.close()
    sys.exit("Sorry, variable %s is not in %s" % (options.var, args[0]))

src = nc.variables[options.var]
attributes = src.ncattrs()
# Decide once whether the source declares a fill value. This replaces a bare
# "except:" that retried the creation after *any* failure, hiding real errors.
has_fill = '_FillValue' in attributes

# TODO: probably the old output variable should be deleted and created from
#   scratch, so there is no risk of old attributes and other stuff left behind.
options.varout = options.var + "_maud"
if options.varout in nc.variables:
    print("Hey! %s is already in this file. I'll overwrite it" %
        options.varout)
    out = nc.variables[options.varout]
    # Reset the previous content to "no data" before it is filled again.
    if has_fill:
        out[:] = src._FillValue
    elif 'missing_value' in attributes:
        out[:] = src.missing_value
elif has_fill:
    out = nc.createVariable(options.varout,
              src.dtype,
              src.dimensions,
              fill_value=src._FillValue)
else:
    out = nc.createVariable(options.varout,
              src.dtype,
              src.dimensions)

# _FillValue is handled through createVariable(fill_value=...) above, so it
# is left out of the attributes copied one by one onto the output variable.
if has_fill:
    attributes.remove('_FillValue')

for a in attributes:
    setattr(out, a, getattr(src, a))


# ==== Handling Lat & Lon variables
import sys

# ---- First I'll guess the name of the variables
if ('latitude' in nc.variables) and ('longitude' in nc.variables):
    lat_var = 'latitude'
    lon_var = 'longitude'
elif ('lat' in nc.variables) and ('lon' in nc.variables):
    lat_var = 'lat'
    lon_var = 'lon'
else:
    # Without the coordinates there is nothing to do. Stop here instead of
    # hitting a NameError on lat_var just below.
    nc.close()
    sys.exit("Sorry, I couldn't guess the name of the lat lon variables.")
lat = nc.variables[lat_var]
lon = nc.variables[lon_var]

# ---- Only 1D lat & lon are supported, they are expanded with a meshgrid
if (len(lat.dimensions) != 1) or (len(lon.dimensions) != 1):
    # Exit with a message and a non-zero status, so the caller can tell that
    # nothing was filtered.
    nc.close()
    sys.exit("Sorry, only 1D %s and %s variables are supported."
        % (lat_var, lon_var))

# The variable must be 3D with lat and lon as its last two dimensions,
# otherwise Lat and Lon would not match the 2D slices filtered later on.
var_dims = nc.variables[options.var].dimensions
if (len(var_dims) != 3) or (lat.dimensions[0] != var_dims[1]) or \
        (lon.dimensions[0] != var_dims[2]):
    nc.close()
    sys.exit("Sorry, %s must be a 3D variable whose last two dimensions "
        "are (%s, %s)." % (options.var, lat.dimensions[0], lon.dimensions[0]))

Lon, Lat = np.meshgrid(lon[:], lat[:])
T, I, J = nc.variables[options.var].shape

# ==== Filtering, one 2D slice per index of the first dimension, in parallel
data = nc.variables[options.var]

# The progress bar is only a nicety: if the package is missing or it can't
# be started, carry on without it.
try:
    from progressbar import ProgressBar
    pbar = ProgressBar(maxval=T).start()
except Exception:
    pbar = None
    print("ProgressBar is not available")

pool = None
try:
    import multiprocessing as mp
    npesmax = 2 * mp.cpu_count() + 1
    if options.npes is None:
        # No --npes given: one process per cpu, which is also what
        # mp.Pool(None) would use.
        npes = mp.cpu_count()
    elif options.npes > npesmax:
        print("Considering the numper of cpu on your machine, I'll stick with npe=%s"  % npesmax)
        npes = npesmax
    else:
        npes = options.npes

    print(" Will work with %s npes" % npes)
    pool = mp.Pool(npes)

    print("Preparing the workers.")
    results = []
    for nt in range(T):
        results.append(pool.apply_async(window_mean_2D_latlon,
            (Lat, Lon, data[nt], options.l, options.windowmethod,
                options.interp)))

    # No more tasks will be submitted.
    pool.close()

    print("Collecting the results.")
    for nt, r in enumerate(results):
        if pbar is not None:
            try:
                pbar.update(nt)
            except Exception:
                # A broken progress bar must not abort the filtering.
                pbar = None

        # What is stored is the original field minus the window mean.
        # NOTE(review): the module docstring describes the output as the
        #   filtered variable; confirm that the residual is what is wanted.
        out[nt] = data[nt] - r.get()

    if pbar is not None:
        try:
            pbar.finish()
        except Exception:
            # Same best effort as above, the data is already written.
            pbar = None

except Exception:
    print("Sorry, didn't work to filter in parallel")
    raise

finally:
    # Whatever happened, don't leave worker processes behind nor the file
    # open. terminate() is harmless once all the results were collected.
    if pool is not None:
        pool.terminate()
        pool.join()
    nc.close()
