#!/usr/bin/env python

""" This is a MAUD script to filter a NetCDF file

    This is far from an ideal solution. It is just meant to solve Bia's
      problem for now.

    This script will use the latitude and longitude provided in the
      netcdf file itself, and apply a moving average window considering
      the linear distances.

    This script is mostly useful for grids that are regular in
      latitude/longitude, and hence not regular in linear distance, or
      for numerical model grids, whose spacing can change along the domain.

      Type this command with -h to see the help. Here is an example:

      maud4latlonnc --shorterpasslength 600e3 --var='temperature' -w hamming --interp --npes=18 --output=temperature_smooth_output.nc model_output.nc

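      This will save the filtered temperature in temperature_smooth_output.nc.
      Without --output, the result is saved in the input NetCDF file itself,
      as another variable named temperature_maud2D.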
"""

from os.path import isfile
import pkg_resources
from datetime import datetime
from optparse import OptionParser

import numpy as np
from numpy import ma
from netCDF4 import Dataset

from cmaud import wmean_2D_latlon_serial as wmean_2D_latlon


# ==== Parsing the options on command line
# Command line interface. The filter type is not an option by itself: the rest
#   of the script derives it from which of --largerpasslength and
#   --shorterpasslength were given, so their help texts must tell them apart.
parser = OptionParser()

parser.add_option("-l",
    action="store", type="float", dest="l",
    help="The filter scale in meters. For example: -l 100e3")

# Given alone, this one selects the low-pass filter.
parser.add_option("--largerpasslength",
    type="float", dest="largerpasslength", default=None,
    help="Length of the filter window in meters. Scales larger than "
         "this length pass the filter (low-pass when given alone).")

# Given alone, this one selects the high-pass filter.
parser.add_option("--shorterpasslength",
    type="float", dest="shorterpasslength", default=None,
    help="Length of the filter window in meters. Scales shorter than "
         "this length pass the filter (high-pass when given alone).")

parser.add_option("--var", dest="var",
    help="Variable to be filtered")

parser.add_option("-w", dest="windowmethod",
    help="Type of window [hamming, hann, boxcar, triangle, lanczos]",
    default="hamming")

parser.add_option("--interp",
    action="store_true", dest="interp",
    help="If selected, fill the masked points when there is available "
         "data around.",
    default=False)

# None means "not given"; the number of processes is chosen later.
parser.add_option("--npes",
    action="store", type="int", dest="npes",
    help="Defines the number of parallel processes.",
    default=None)

# When absent, the result is written into the input file itself.
parser.add_option("--output", dest="outputfile",
    help="The output file where the filtered data will be saved.")


(options, args) = parser.parse_args()

# ============================================================================
# The type of filter is implied by which cutoff lengths were given:
#   only --shorterpasslength -> highpass
#   only --largerpasslength  -> lowpass
#   both                     -> bandpass
has_larger = options.largerpasslength is not None
has_shorter = options.shorterpasslength is not None

if has_larger and has_shorter:
    filtertype = 'bandpass'
elif has_larger:
    filtertype = 'lowpass'
elif has_shorter:
    filtertype = 'highpass'
else:
    # Without any cutoff there is nothing to do. Say so, instead of quitting
    #   silently with a success status. parser.error() prints the usage and
    #   the message on stderr and exits with status 2.
    parser.error("Please define --largerpasslength and/or "
                 "--shorterpasslength, so I know which filter to apply.")

# ============================================================================
# Validate the inputs explicitly. An assert would be skipped under "python -O"
#   and would give an IndexError/TypeError instead of a hint to the user.
if len(args) < 1:
    raise SystemExit("Please give me the NetCDF file to be filtered.")
if not isfile(args[0]):
    raise SystemExit("Couldn't find the input file: %s" % args[0])
if options.var is None:
    raise SystemExit("Please define the variable to be filtered with --var.")

varname = options.var

# With --output the filtered field goes to a new file, keeping the variable
#   name. Otherwise it goes to the input file itself, under another name.
newfile = options.outputfile is not None
if newfile:
    options.varout = options.var
else:
    options.varout = options.var + "_maud"

# Version of the installed maud package, saved later as an attribute of the
#   output variable.
ver_maud = pkg_resources.get_distribution("maud").version
# ============================================================================
# ==== Handling Lat & Lon variables
# ---- First I'll guess the name of the variables
# Builds the 2D fields Lat & Lon, and gets the shape (T, I, J) of the variable
#   to be filtered. Every unsupported layout stops the script here, with a
#   message, instead of failing later on an undefined Lat/Lon.
ncin = Dataset(args[0], 'r')
try:
    # ---- Guessing the name of the variables
    # TODO: consider other names, like 'Lat' & 'Lon'.
    if ('latitude' in ncin.variables) and ('longitude' in ncin.variables):
        lat_var, lon_var = 'latitude', 'longitude'
    elif ('lat' in ncin.variables) and ('lon' in ncin.variables):
        lat_var, lon_var = 'lat', 'lon'
    else:
        raise SystemExit(
            "Sorry, I couldn't guess the name of the lat lon variables.")

    lat = ncin.variables[lat_var]
    lon = ncin.variables[lon_var]

    if options.var not in ncin.variables:
        raise SystemExit(
            "Sorry, there is no variable %s in %s" % (options.var, args[0]))
    var_dims = ncin.variables[options.var].dimensions
    var_shape = ncin.variables[options.var].shape

    # The variable is filtered one 2D field at a time, along its first
    #   dimension, so it must be 3D.
    if len(var_dims) != 3:
        raise SystemExit(
            "Sorry, I only know how to handle 3D variables, but %s has "
            "%s dimension(s)." % (options.var, len(var_dims)))

    # ---- Only 1D lat & lon are supported, and then I need a meshgrid
    # TODO: support 2D lat & lon sharing the last two dimensions of the
    #   variable.
    if (len(lat.dimensions) != 1) or (len(lon.dimensions) != 1):
        raise SystemExit(
            "Sorry, I only know how to handle 1D %s & %s."
            % (lat_var, lon_var))

    if (lat.dimensions[0] != var_dims[1]) or \
            (lon.dimensions[0] != var_dims[2]):
        raise SystemExit(
            "lat,lon dimensions don't match with %s." % options.var)

    Lon, Lat = np.meshgrid(lon[:], lat[:])
    T, I, J = var_shape
finally:
    ncin.close()


# Packing parameters of the variable, if any. They stay None when the
#   variable doesn't have the respective attribute.
ncin = Dataset(args[0], 'r')
try:
    packed_var = ncin.variables[varname]
    # Check for the attributes explicitly, so any other problem (a missing
    #   variable, a non numeric attribute) is not silently taken as
    #   "there is no packing".
    packed_attrs = packed_var.ncattrs()

    if 'add_offset' in packed_attrs:
        add_offset = float(getattr(packed_var, 'add_offset'))
    else:
        add_offset = None

    if 'scale_factor' in packed_attrs:
        scale_factor = float(getattr(packed_var, 'scale_factor'))
    else:
        scale_factor = None
finally:
    ncin.close()

# ============================================================================
if newfile is True:
    # Saves the filtered variable in a new file (--output), together with the
    #   dimensions of the variable and their coordinate variables. The script
    #   ends at the bottom of this block.
    import sys
    import multiprocessing as mp

    # ---- Scale of the filter
    # Checked before creating anything, so an unsupported filter doesn't leave
    #   a half written output file behind.
    if filtertype == 'highpass':
        l = options.shorterpasslength
    elif filtertype == 'lowpass':
        l = options.largerpasslength
    else:
        sys.exit("Sorry, the %s filter is not implemented yet." % filtertype)

    # ---- Number of processes
    npesmax = 2 * mp.cpu_count() + 1
    # npes is None when --npes wasn't given.
    if (options.npes is not None) and (0 < options.npes <= npesmax):
        npes = options.npes
    else:
        print("Considering the number of cpu on your machine, I'll stick with npe=%s"  % npesmax)
        npes = npesmax

    print(" Will work with %s npes" % npes)

    ncin = Dataset(args[0], 'r')
    ncout = Dataset(options.outputfile, 'w')
    try:
        data = ncin.variables[varname]
        dims = data.dimensions

        # Copying dimensions
        for dim in dims:
            ncout.createDimension(dim, len(ncin.dimensions[dim]))

        # Copying variables related to the dimensions
        for dim in dims:
            if dim not in ncin.variables:
                # A dimension doesn't need to have a coordinate variable.
                continue
            coord_in = ncin.variables[dim]
            coord_out = ncout.createVariable(dim,
                    coord_in.datatype.name,
                    coord_in.dimensions)
            coord_out[:] = coord_in[:]
            for a in coord_in.ncattrs():
                setattr(coord_out, a, getattr(coord_in, a))

        # The _FillValue must be defined when the variable is created, and
        #   only if the original variable has one.
        src_attrs = data.ncattrs()
        create_kwargs = {}
        if '_FillValue' in src_attrs:
            create_kwargs['fill_value'] = data._FillValue
        output = ncout.createVariable(varname, data.dtype.name, dims,
                **create_kwargs)

        # The other attributes are only copied after writing the data, see
        #   below.
        attrs = [a for a in src_attrs if a != '_FillValue']

        pool = mp.Pool(npes)

        print("Preparing the workers.")
        results = [
                pool.apply_async(
                    wmean_2D_latlon,
                    (Lat, Lon, data[nt], l, options.windowmethod,
                        options.interp))
                for nt in range(T)]
        pool.close()

        print("Collecting the results.")
        # Integer outputs must be rounded, otherwise the values are truncated
        #   when saved.
        round_output = np.issubdtype(output.dtype, np.integer)
        for nt, r in enumerate(results):
            if filtertype == 'lowpass':
                tmp = r.get()
            else:
                # highpass is what is left after removing the smooth field
                tmp = data[nt] - r.get()

            # At this point the output doesn't have scale_factor & add_offset,
            #   so if the original variable is packed I need to pack the
            #   output by hand.
            if add_offset is not None:
                tmp = tmp - add_offset
            if scale_factor is not None:
                tmp = tmp / scale_factor

            if round_output:
                output[nt] = tmp.round()
            else:
                output[nt] = tmp

        for a in attrs:
            setattr(output, a, getattr(data, a))

        # Saving the filter window size as an attribute of the output variable
        setattr(output, 'filter_window_size', l)
        # Saving the version as an attribute of the output variable
        setattr(output, 'maud_version', ver_maud)
        # Saving the filtering date  as an attribute of the output variable
        setattr(output, 'maud_processing_date', datetime.now().isoformat())

        ncout.sync()
    finally:
        # Close the files even if something went wrong on the way.
        ncin.close()
        ncout.close()

    # All done. What comes below is only for filtering in the same file.
    sys.exit()


# Here goes if newfile is False
# Saves the filtered variable in the input file itself, as <var>_maud2D.
#
# ---- Scale of the filter
# An explicit -l keeps the behavior this branch always had: the smooth field
#   itself is saved. Without -l, the scale comes from the cutoff option that
#   defined the filter type, and a highpass saves what is left after removing
#   the smooth field, just like when saving in a new file.
# NOTE(review): before, the scale was always options.l, which is None if only
#   --largerpasslength or --shorterpasslength was given. Please confirm that
#   this fallback is the intended behavior.
scale = options.l
remove_smooth = False
if scale is None:
    if filtertype == 'lowpass':
        scale = options.largerpasslength
    elif filtertype == 'highpass':
        scale = options.shorterpasslength
        remove_smooth = True
    else:
        raise SystemExit(
            "Sorry, the %s filter is not implemented yet." % filtertype)

nc = Dataset(args[0], 'a')
try:
    data = nc.variables[options.var]
    attributes = data.ncattrs()

    # TODO: Think well how to do it. Probably I should first delete the old
    #   one and create from scratch, so there is no risk of old attributes and
    #   other stuff left behind.
    # NOTE(review): this replaces the <var>_maud name defined earlier in the
    #   script. Confirm which one is the intended name.
    options.varout = "%s_maud2D" % options.var
    if options.varout in nc.variables:
        print("Hey! %s is already in this file. I'll overwrite it" % \
            options.varout)
        out = nc.variables[options.varout]
        # Clean whatever is left from a previous run.
        if '_FillValue' in attributes:
            out[:] = data._FillValue
            attributes.remove('_FillValue')
        elif 'missing_value' in attributes:
            out[:] = data.missing_value
    elif '_FillValue' in attributes:
        # The _FillValue must be defined when the variable is created, so it
        #   can't be copied together with the other attributes.
        out = nc.createVariable(options.varout,
                  data.dtype,
                  data.dimensions,
                  fill_value=data._FillValue)
        attributes.remove('_FillValue')
    else:
        out = nc.createVariable(options.varout,
                  data.dtype,
                  data.dimensions)

    for a in attributes:
        setattr(out, a, getattr(data, a))

    # Saving the filter window size as an attribute of the output variable
    setattr(out, 'filter_window_size', scale)
    # Saving the version as an attribute of the output variable
    setattr(out, 'maud_version', ver_maud)
    # Saving the filtering date  as an attribute of the output variable
    setattr(out, 'maud_processing_date', datetime.now().isoformat())

    # The progress bar is optional, so any trouble with it is not fatal.
    try:
        from progressbar import ProgressBar
        pbar = ProgressBar(maxval=T).start()
    except Exception:
        pbar = None
        print("ProgressBar is not available")

    try:
        import multiprocessing as mp
        npesmax = 2 * mp.cpu_count() + 1
        # npes is None when --npes wasn't given.
        if (options.npes is not None) and (0 < options.npes <= npesmax):
            npes = options.npes
        else:
            print("Considering the number of cpu on your machine, I'll stick with npe=%s"  % npesmax)
            npes = npesmax

        print(" Will work with %s npes" % npes)
        pool = mp.Pool(npes)

        print("Preparing the workers.")
        results = [
                pool.apply_async(
                    wmean_2D_latlon,
                    (Lat, Lon, data[nt], scale, options.windowmethod,
                        options.interp))
                for nt in range(T)]

        pool.close()
        print("Collecting the results.")
        for nt, r in enumerate(results):
            if pbar is not None:
                pbar.update(nt)

            if remove_smooth:
                out[nt] = data[nt] - r.get()
            else:
                out[nt] = r.get()

    except Exception:
        print("Sorry, didn't work to filter in parallel")
        raise
finally:
    # Close the file even if something went wrong on the way.
    nc.close()
