#!/usr/bin/env python

import fiona
import shapely
import numpy as np
import sys
import argparse

import climstats.dataset
import climstats.gridfunctions


def _property_dtype(typestring):
    """Return the dtype used to store one shapefile property column.

    Only the part of *typestring* before the first ``:`` is inspected
    (e.g. ``'str:80'`` -> ``'str'``).  ``'str'`` columns are stored as Python
    objects; every other type is stored as ``np.float32``.
    """
    if typestring.split(':')[0] == 'str':
        return object
    return np.float32


def _overlap_weights(grid_polys, geometry, grid_shape):
    """Return the fraction of each grid cell's area covered by *geometry*.

    *grid_polys* is iterated as rows of cell polygons; each polygon must offer
    ``intersects``, ``intersection`` and ``area``.  The result is a float32
    array of shape *grid_shape*, indexed ``[row, column]``, holding zero for
    cells that do not touch *geometry*.
    """
    weights = np.zeros(grid_shape, dtype=np.float32)
    for y, row in enumerate(grid_polys):
        for x, poly in enumerate(row):
            # TODO: pre-cull cells using the geometry's bounding box.
            if poly.intersects(geometry):
                weights[y, x] = poly.intersection(geometry).area / poly.area
    return weights


def _normalise_weights(weights):
    """Scale *weights* so that they sum to one.

    When the weights sum to zero (the feature misses the grid entirely) an
    all-NaN array is returned.  That is the value a plain ``0 / 0`` division
    would produce, but made explicit and without a NumPy RuntimeWarning.
    """
    total = np.sum(weights)
    if total == 0:
        return np.full(weights.shape, np.nan, dtype=weights.dtype)
    return weights / total


def _weighted_series(values, weights):
    """Reduce *values* to one weighted sum per entry along its first axis.

    *values* is multiplied by *weights* (NumPy broadcasting) and summed over
    every axis except the first.  Returns an array of shape ``(n, 1)`` where
    ``n`` is the length of the first axis.  Raises ``ValueError`` when the
    shapes cannot be broadcast or the result cannot be reshaped to ``(n, 1)``.
    """
    axes = tuple(range(1, np.ndim(values)))
    summed = np.sum(values * weights, axis=axes, keepdims=True)
    return summed.reshape((summed.shape[0], 1))


def _aggregate(shpfile, nc_path, varname, out_path):
    """Write per-feature, area-weighted sums of the NetCDF variables.

    shpfile  -- open fiona collection providing the template features
    nc_path  -- path of the source NetCDF file
    varname  -- variable whose latitude/longitude/time coordinates define
                the grid and the output time axis
    out_path -- path of the NetCDF file to write
    """
    # A bare `import shapely` does not guarantee the submodule is loaded.
    import shapely.geometry

    feature_count = len(shpfile)
    print("{} features found".format(feature_count))

    # Open source netcdf data file
    ds = climstats.dataset.NetCDF4Dataset(nc_path)
    variable = ds.variables[varname]
    print(ds.variables)

    # Try and get the latitude/longitude variables
    lats = variable.coords['latitude'][:]
    lons = variable.coords['longitude'][:]

    # Construct polygon grid
    grid_shape, grid_polys = climstats.gridfunctions.makegrid(lats, lons)
    print("grid has shape {}".format(repr(grid_shape)))

    # Output variables mirror the source shapefile schema
    schema = shpfile.schema
    print(schema)

    time_coord = variable.coords['time']
    outds = climstats.dataset.Dataset(
        dimensions=[('time', time_coord.shape[0]), ('feature', feature_count)])

    tvar = climstats.dataset.Variable(
        'time', outds, dimensions=['time'], dtype=np.float32)
    tvar.attributes['units'] = time_coord.attributes['units']
    tvar[:] = time_coord[:]

    # `np.int` was only an alias of the builtin and is gone from NumPy >= 1.24.
    idvar = climstats.dataset.Variable(
        'id', outds, dimensions=['feature'], dtype=int)

    ancilvars = {}
    for key, typestring in schema['properties'].items():
        ancilvars[key] = climstats.dataset.Variable(
            key, outds, dimensions=['feature'],
            dtype=_property_dtype(typestring))
        print(ancilvars[key])

    outvars = {}
    for name, var in ds.variables.items():
        outvars[name] = climstats.dataset.Variable(
            name, outds, dimensions=['time', 'feature'], dtype=var.dtype)

    print("")
    print(outds.variables)
    print(outds.coords)
    print("")

    # We'll keep all the weights, we could write these out as well
    weights = np.zeros(
        (feature_count, grid_shape[0], grid_shape[1]), dtype=np.float32)

    # Okay, here we go: iterate through the template features
    for feature_id, feature in enumerate(shpfile):
        print(feature['properties'])

        # Get the actual geometry
        geometry = shapely.geometry.shape(feature['geometry'])

        raw = _overlap_weights(grid_polys, geometry, weights.shape[1:])
        if not raw.any():
            print("feature {} does not overlap the grid; "
                  "its values will be NaN".format(feature_id))
        # Normalize to sum = 1
        weights[feature_id] = _normalise_weights(raw)

        for name, var in ds.variables.items():
            print(name)
            # Best effort: variables that do not fit the grid are skipped,
            # but the reason is reported instead of being swallowed silently.
            try:
                series = _weighted_series(var[:], weights[feature_id])
                print(series.shape)
                outvars[name][:, feature_id] = series
            except Exception as err:
                print("skipping variable {}: {}".format(name, err))

        idvar[feature_id] = feature_id

        for key, value in feature['properties'].items():
            # %-formatting keeps non-ASCII text printable on Python 2 as well.
            print("%s %s" % (key, value))
            ancilvars[key][feature_id] = value

    climstats.dataset.NetCDF4Dataset.write(outds, out_path)


def main(argv=None):
    """Run the script: ``SHAPEFILE NETCDF_IN VARIABLE NETCDF_OUT``.

    *argv* defaults to ``sys.argv``; element 0 is the program name.  Exits
    with a usage message when fewer than four arguments are supplied.
    """
    argv = sys.argv if argv is None else argv
    if len(argv) < 5:
        sys.exit("usage: {} SHAPEFILE NETCDF_IN VARIABLE NETCDF_OUT".format(
            argv[0] if argv else "<script>"))
    shp_path, nc_path, varname, out_path = argv[1:5]

    # Read source/template shapefile; `with` closes it even on failure.
    with fiona.open(shp_path) as shpfile:
        _aggregate(shpfile, nc_path, varname, out_path)


if __name__ == "__main__":
    main()



