#!/usr/bin/env python

#    view *.tmi images for TFCE_mediation
#    Copyright (C) 2016  Tristram Lett

#    This program is free software: you can redistribute it and/or modify
#    it under the terms of the GNU General Public License as published by
#    the Free Software Foundation, either version 3 of the License, or
#    (at your option) any later version.

#    This program is distributed in the hope that it will be useful,
#    but WITHOUT ANY WARRANTY; without even the implied warranty of
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#    GNU General Public License for more details.

#    You should have received a copy of the GNU General Public License
#    along with this program.  If not, see <http://www.gnu.org/licenses/>.

from __future__ import division
import os
import sys
import numpy as np
import warnings
import argparse as ap
import nibabel as nib
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import warnings
import matplotlib.cbook

from tfce_mediation.tm_io import read_tm_filetype
from tfce_mediation.tm_func import print_tmi_history
from tfce_mediation.pyfunc import convert_voxel, convert_fs, write_colorbar, create_adjac_vertex, vectorized_surface_smooth

# hopefully safer loading of mayavi
# QT_API is set before importing mayavi; presumably mayavi/pyface read it at import
# time to choose the Qt binding (TODO confirm). PyQt is the default unless the user has
# already set QT_API.
if 'QT_API' not in os.environ:
	os.environ['QT_API'] = 'pyqt'
try:
	from mayavi import mlab
except:
	# Any failure (bare except) triggers one retry with PySide, overriding any
	# user-supplied QT_API; if this second import also fails the error propagates.
	print "Trying pyside"
	os.environ['QT_API'] = 'pyside'
	from mayavi import mlab

def check_byteorder(np_array):
	"""
	Return np_array in the machine's native byte order.

	Callers use this as an endianness fix before handing image data to mayavi.
	Arrays with a non-native dtype (e.g. big-endian data on a little-endian machine)
	are converted to the equivalent native dtype; the values are unchanged.

	Arrays that are already native are returned unchanged. numpy reports these with
	byteorder '=' rather than '<'/'>'. Arrays whose dtype has no byte order ('|',
	e.g. uint8 or bool) are also returned unchanged.
	"""
	# dtype.isnative is True for '=' and '|' dtypes. Comparing dtype.byteorder against
	# '<'/'>' would wrongly treat every native array as foreign and swap it.
	if np_array.dtype.isnative:
		return np_array
	# astype() converts the bytes while keeping the values; this also avoids
	# ndarray.newbyteorder(), which was removed in NumPy 2.
	return np_array.astype(np_array.dtype.newbyteorder('='))

def display_matplotlib_luts():
	"""
	Draw every lookup table accepted by -lut as a labelled colour strip and show it.

	The matplotlib colormaps (without their '_r' reversed twins) are listed first in
	sorted order. They are followed by the three custom tables red-yellow,
	blue-lightblue and green-lightgreen.

	Adapted from https://matplotlib.org/1.2.1/examples/pylab_examples/show_colormaps.html
	That example came from the scipy.org Cookbook; according to its history Andrew Straw
	converted it from an older page, and the original author is unknown.
	"""
	plt.switch_backend('Qt4Agg')
	warnings.filterwarnings("ignore",category=matplotlib.cbook.mplDeprecation)

	# two identical rows of a 0..1 gradient, drawn through each colormap
	strip = np.linspace(0, 1, 256).reshape(1,-1)
	strip = np.vstack((strip, strip))

	# RGB channels (0-255) of the custom tables, in display order
	rising = np.linspace(0,255,256)
	saturated = np.ones(256)*255
	off = np.zeros(256)
	custom_tables = [
		(u'red-yellow', (saturated, rising, off)),
		(u'blue-lightblue', (off, rising, saturated)),
		(u'green-lightgreen', (off, np.linspace(128,255,256), off)),
	]
	custom_lookup = dict(custom_tables)

	names = sorted(m for m in plt.cm.datad if not m.endswith("_r"))
	names.extend(label for label, _ in custom_tables)
	n_rows = len(names) + 1

	fig = plt.figure(figsize=(8,10))
	fig.subplots_adjust(top=0.99, bottom=0.01, left=0.2, right=0.99)
	for row, name in enumerate(names, 1):
		ax = plt.subplot(n_rows, 1, row)
		plt.axis("off")
		channels = custom_lookup.get(name)
		if channels is None:
			cmap = plt.get_cmap(name)
		else:
			cmap = colors.ListedColormap(np.array(channels).T / 255, name)
		plt.imshow(strip, aspect='auto', cmap=cmap, origin='lower')
		left, bottom = ax.get_position().bounds[:2]
		fig.text(left - 0.01, bottom, name, fontsize=10, horizontalalignment='right')
	plt.show()


def get_cmap_array(lut, alpha = 255, zero_lower = True, zero_upper = False, base_color = (227,218,201), c_reverse = False):
	"""
	Build a 256-entry RGBA look-up table with channel values in 0-255.

	Parameters
	----------
	lut : str
		'red-yellow' ('r-y'), 'blue-lightblue' ('b-lb'), 'green-lightgreen' ('g-lg')
		or any matplotlib colormap name. A trailing '_r' reverses the table.
	alpha : number
		Alpha (0-255) appended to base_color when base_color has only RGB.
	zero_lower, zero_upper : bool
		Replace the first / last entry of the table with base_color.
	base_color : sequence of 3 or 4 numbers
		RGB(A) colour in 0-255. It is copied, never modified.
	c_reverse : bool
		Reverse the table. This is also switched on by a '_r' suffix on lut.

	Returns
	-------
	numpy.ndarray
		Array of shape (256, 4).

	Prints an error and exits the program if lut is not recognised.
	"""
	# Work on a copy. Appending to the shared default list made every later call reuse
	# the first call's alpha; it also mutated lists passed in by callers.
	base_color = list(base_color)
	if len(base_color) == 3:
		base_color.append(int(alpha))
	lut = str(lut)
	if lut.endswith('_r'):
		c_reverse = True
		lut = lut[:-2]

	# make custom look-up table
	ramp = np.linspace(0, 255, 256)
	full = np.ones(256) * 255.0
	empty = np.zeros(256)
	if lut in ('r-y', 'red-yellow'):
		cmap_array = np.array((full, ramp, empty, full)).T
	elif lut in ('b-lb', 'blue-lightblue'):
		cmap_array = np.array((empty, ramp, full, full)).T
	elif lut in ('g-lg', 'green-lightgreen'):
		cmap_array = np.array((empty, ramp, empty, full)).T
	else:
		# Look the colormap up by name rather than eval()-ing command-line text.
		try:
			cmap = plt.get_cmap(lut)
		except (ValueError, KeyError, AttributeError):
			print("Error: Lookup table '%s' is not recognized." % lut)
			print("The lookup table can be red-yellow (r-y), blue-lightblue (b-lb), green-lightgreen (g-lg) or any matplotlib colorschemes (https://matplotlib.org/examples/color/colormaps_reference.html)")
			sys.exit()
		# matplotlib returns RGBA in 0-1; scale to 0-255 like the custom tables
		cmap_array = cmap(np.arange(256)) * 255
	if c_reverse:
		cmap_array = cmap_array[::-1]
	if zero_lower:
		cmap_array[0] = base_color
	if zero_upper:
		cmap_array[-1] = base_color
	return cmap_array

# Program description passed to the ArgumentParser that getArgumentParser() creates by default.
DESCRIPTION = """
Display surfaces and volumes from a tmi file.

This script relies on Mayavi. If you use it please cite:

Ramachandran, P. and Varoquaux, G., `Mayavi: 3D Visualization of Scientific Data` IEEE Computing in Science & Engineering, 13 (2), pp. 40-51 (2011)
"""

def getArgumentParser(ap = None):
	"""
	Add this tool's command-line options to an argparse parser and return it.

	Parameters
	----------
	ap : argparse.ArgumentParser, optional
		Parser (or subparser) to populate. When omitted, a new parser described by
		DESCRIPTION is created on every call. A parser built once at definition time
		would receive duplicate options (and raise) if this function ran twice.

	Returns
	-------
	argparse.ArgumentParser
		The populated parser.
	"""
	if ap is None:
		# the parameter name shadows the module-level "argparse as ap" alias
		import argparse
		ap = argparse.ArgumentParser(description = DESCRIPTION)

	# exactly one of: a tmi input, explicitly no tmi, or the LUT preview
	group = ap.add_mutually_exclusive_group(required=True)
	group.add_argument("-i", "-i_tmi", "--inputtmi",
		help="Input the *.tmi file containing the statistics to view.",
		nargs=1, 
		metavar='*.tmi')
	group.add_argument("--no-tmi",
		help="Do not input a TMI file.",
		action = 'store_true')

	ap.add_argument("-oh", "--history",
		help="Output tmi file history and exits.", 
		action='store_true')
	ap.add_argument("-d", "--display",
		help="""
			Select which object to display. The mask, surface, contrast must be entered as integers (check values 
			with -oh). Multiple objects can be displayed. The input must be divisible by three.""",
		nargs = '+',
		type = int,
		metavar = 'int')
	# optional
	ap.add_argument("-dv", "--displayvoxelmask",
		help = "Display a volume as a surface. The mask and contrast must be entered as integers (check values with -oh).", 
		nargs = '+',
		type = int,
		metavar = 'int')

	ap.add_argument("-lut", "--lookuptable",
		help = """
			Set the color map to display. The lookuptable can be red-yellow (r-y), blue-lightblue (b-lb), green-lightgreen (g-lg) or any 
			matplotlib colorschemes. Any lut can be reversed by appending _r (e.g. -lut red-yellow_r). Use --plotluts to see the available lookuptables""", 
		type = str,
		default = ['r-y'],
		nargs = 1)
	group.add_argument("--plotluts",
		help = "Plots the available lookup tables, and exits.",
		action = 'store_true')
	ap.add_argument("-t", "--thresholds",
		help = "Set lower and upper thresholds. Default is: %(default)s", 
		default=[.95,1],
		type = float,
		nargs = 2)
	ap.add_argument("-a", "--alpha",
		help = "Set alpha [0 to 255]", 
		default=[255],
		type = int,
		nargs = 1)
	ap.add_argument("-o", "--opacity",
		help = "Set opacity [0 to 1]", 
		default=[1.0],
		type = float,
		nargs = 1)

	#new
	ap.add_argument("-ds", "--displaysurface",
		help = "Display a surface without a scalar (i.e., just the surface).", 
		nargs = '+',
		type = int,
		metavar = 'int')
	ap.add_argument("-iv", "--importvolume",
		help = "Import volume(s) (nifti or minc). e.g., -iv mean_FA_skeleton_mask.nii.gz", 
		nargs = '+',
		type = str,
		metavar = 'str')
	ap.add_argument("-ivb", "--binarizedimportvolume",
		help = "Import volume(s) (nifti or minc) and set all non-zero values to one. e.g., -ivb mean_FA_skeleton_mask.nii.gz", 
		nargs = '+',
		type = str,
		metavar = 'str')
	ap.add_argument("-ivt", "--thresholdimportvolume",
		help = """
			Import a volume (nifti or minc) and apply a lower and upper threshold. e.g., -ivt mean_FA_skeleton_mask.nii.gz  0.2 0.8. 
			Only one volume can be added. A lookup table can be specified as the fourth argument (the default is the -lut option)
			and an opacity [0 to 1] as the fifth argument (the default is the -o option).
		""", 
		nargs = '*',
		# a plain string: a list metavar is rendered as "['...']" in usage/help
		metavar = 'str|float|float|str|float')
	ap.add_argument("-ifs", "--importfreesurfer",
		help = """
			Import a freesurfer surface (use 'tm_multimodal create-tmi' to add other surface types). 
			e.g., -ifs $SUBJECT_DIR/fsaverage/surf/lh.midthickness""", 
		nargs = '+',
		type = str,
		metavar = 'str')
	ap.add_argument("-fsmgh", "--importfreesurfermgh",
		help = """
			Import a freesurfer scalar (mgh) file(s) that match the input surface(s). The number of inputs 
			must match the -ifs inputs. e.g., -ifs ?h.midthickness -fsmgh ?h_results_file.mgh
			""", 
		nargs = '+',
		type = str,
		metavar = 'str')
	ap.add_argument("-save", "--savesnapshots",
		help = "Save snapshots of the image. Input the basename of the output.", 
		nargs = 1,
		type = str,
		metavar = 'basename')
	ap.add_argument("--savetype",
		help = "Choose output snapshot type by file extension. Default is: %(default)s", 
		nargs = 1,
		type = str,
		default = ['png'],
		metavar = 'filetype')
	ap.add_argument("-ss","--surfacesmoothing",
		help = "Apply Laplacian or Taubin smoothing before visualization. Input the number of iterations (e.g., -ss 5).", 
		nargs = 1,
		type = int,
		metavar = 'int')
	ap.add_argument("-stype","--smoothingtype",
		help = """Set type of surface smoothing to use (choices are: %(choices)s). The default is laplacian
				The Taubin (aka low-pass) filter smooths curves/surfaces without the shrinkage of the laplacian filter.
				""", 
		nargs = 1,
		choices = ['laplacian','taubin'],
		default = ['laplacian'],
		metavar = 'str')

	return ap

def _mask_positions(masking_array):
	"""
	Return the row offsets of each mask's data within image_array[0].

	Entry i is where mask i's data starts and entry i+1 is where it ends, so the list
	has len(masking_array) + 1 elements.
	"""
	positions = [0]
	for mask in masking_array:
		positions.append(positions[-1] + int(np.count_nonzero(mask)))
	return positions


def _smooth_mesh(v, f, opts, scalar_data = None):
	"""
	Apply the optional -ss/-stype smoothing to a mesh (and its per-vertex scalars).

	Returns (v, f, scalar_data). The inputs are returned untouched when -ss was not
	given. When scalar_data is None, only the geometry is smoothed and None is returned
	in its place.
	"""
	if not opts.surfacesmoothing:
		return v, f, scalar_data
	adjacency = create_adjac_vertex(v,f)
	if scalar_data is None:
		v, f = vectorized_surface_smooth(v, f, adjacency,
			number_of_iter = int(opts.surfacesmoothing[0]),
			mode = str(opts.smoothingtype[0]))
	else:
		v, f, scalar_data = vectorized_surface_smooth(v, f, adjacency,
			number_of_iter = int(opts.surfacesmoothing[0]),
			scalar = scalar_data,
			mode = str(opts.smoothingtype[0]))
	return v, f, scalar_data


def _save_snapshots(surf, basename, filetype):
	"""Save '<basename>_<view>.<filetype>' for seven views of surf's scene on a black background."""
	scene = surf.scene
	scene.background = (0,0,0)
	views = [
		(scene.x_minus_view, 'left'),
		(scene.x_plus_view, 'right'),
		(scene.y_minus_view, 'posterior'),
		(scene.y_plus_view, 'anterior'),
		(scene.z_minus_view, 'inferior'),
		(scene.z_plus_view, 'superior'),
		(scene.isometric_view, 'isometric')]
	for set_view, orientation in views:
		set_view()
		mlab.savefig('%s_%s.%s' % (basename, orientation, filetype), size = (1600,1400))


def run(opts):
	"""
	Render the objects requested on the command line with mayavi and open the viewer.

	Objects are added in this order: tmi surfaces with scalars (-d), tmi voxel masks
	(-dv), bare tmi surfaces (-ds), imported volumes (-iv, -ivb, -ivt) and FreeSurfer
	surfaces (-ifs, optionally with -fsmgh scalars). With -save, snapshots of seven views
	and a colour bar are written first. --plotluts and --history exit after output.
	"""
	if opts.plotluts:
		display_matplotlib_luts()
		sys.exit()

	# if not enough inputs, output history
	if len(sys.argv) <= 3:
		if opts.inputtmi:
			opts.history = True
		else:
			print("Error: no surface or volume imports (use -iv or -ifs).")
			sys.exit()

	# load tmi; this must happen before --history, which reports on its contents
	if opts.inputtmi:
		_, image_array, masking_array, maskname_array, affine_array, vertex_array, face_array, surfname, _, tmi_history, columnids = read_tm_filetype(opts.inputtmi[0], verbose=False)
		# get the positions of masked data in image_array
		position_array = _mask_positions(masking_array)
	elif opts.history or opts.display or opts.displayvoxelmask or opts.displaysurface:
		print("Error: -oh, -d, -dv and -ds require a tmi file (use -i).")
		sys.exit()

	if opts.history:
		num_con = image_array[0].shape[1]
		if num_con > 500: # safer
			num_con = None
		print_tmi_history(tmi_history, 
			maskname_array, 
			surfname, 
			num_con = num_con,
			contrast_names = columnids)
		sys.exit()

	# make custom look-up table
	cmap_array = get_cmap_array(opts.lookuptable[0], float(opts.alpha[0]))
	# Variant with a fully transparent lowest entry for voxel-derived meshes. It is a
	# separate copy so the transparency cannot leak into objects drawn with cmap_array.
	transparent_cmap_array = cmap_array.copy()
	transparent_cmap_array[0] = [227,218,201,0]
	base_color = (227/255, 218/255, 201/255)
	surf = None # most recently added mesh; its scene is used for display settings

	if opts.display:
		if len(opts.display) % 3 != 0:
			print("Error: There must be three inputs per surface (mask, surface, contrast).")
			sys.exit()
		# display the surfaces
		for i in range(0, len(opts.display), 3):
			c_mask, c_surf, c_contrast = opts.display[i:i+3]
			start = position_array[c_mask]
			end = position_array[c_mask+1]
			mask = masking_array[c_mask]
			scalar_data = np.zeros((mask.shape[0]))
			# surface masks are indexed as (n_vertices, 1, 1)
			scalar_data[mask[:,0,0]] = image_array[0][start:end,c_contrast]
			v, f, scalar_data = _smooth_mesh(vertex_array[c_surf][:], face_array[c_surf][:], opts, scalar_data)

			surf = mlab.triangular_mesh(v[:,0], v[:,1], v[:,2], f,
				scalars = scalar_data,
				opacity = opts.opacity[0],
				vmin=opts.thresholds[0],
				vmax=opts.thresholds[1],
				name = maskname_array[c_mask])
			surf.module_manager.scalar_lut_manager.lut.table = cmap_array
			surf.actor.mapper.interpolate_scalars_before_mapping = 1

	if opts.displayvoxelmask:
		if len(opts.displayvoxelmask) % 2 != 0:
			print("Error: There must be two inputs per voxel surface (mask, contrast).")
			sys.exit()
		for j in range(0, len(opts.displayvoxelmask), 2):
			c_mask, c_contrast = opts.displayvoxelmask[j:j+2]
			start = position_array[c_mask]
			end = position_array[c_mask+1]
			mask = masking_array[c_mask]
			scalar_data = np.zeros((mask.shape[0],mask.shape[1],mask.shape[2]))
			scalar_data[mask] = image_array[0][start:end,c_contrast]
			v, f, scalar_data = convert_voxel(scalar_data, affine = affine_array[c_mask])
			v, f, scalar_data = _smooth_mesh(v, f, opts, scalar_data)

			surf = mlab.triangular_mesh(v[:,0], v[:,1], v[:,2], f,
				scalars = scalar_data, 
				vmin=opts.thresholds[0],
				vmax=opts.thresholds[1],
				name = maskname_array[c_mask])
			surf.module_manager.scalar_lut_manager.lut.table = transparent_cmap_array
			surf.actor.mapper.interpolate_scalars_before_mapping = 1

	if opts.displaysurface:
		for c_surf in opts.displaysurface:
			v, f, _ = _smooth_mesh(vertex_array[c_surf], face_array[c_surf], opts)
			surf = mlab.triangular_mesh(v[:,0], v[:,1], v[:,2], f,
				opacity = opts.opacity[0],
				name = surfname[c_surf],
				color = base_color)
			surf.actor.mapper.interpolate_scalars_before_mapping = 1

	if opts.importvolume:
		for volume in opts.importvolume:
			invol = nib.load(volume)
			data = check_byteorder(invol.get_data())
			v, f, scalar_data = convert_voxel(data, affine = invol.get_affine())
			v, f, scalar_data = _smooth_mesh(v, f, opts, scalar_data)

			# non-zero values averaging exactly 1 (e.g. a binary mask): draw in the plain base colour
			if np.mean(scalar_data[scalar_data != 0]) == 1:
				surf = mlab.triangular_mesh(v[:,0], v[:,1], v[:,2], f,
					scalars = scalar_data, 
					name = os.path.basename(volume),
					opacity = opts.opacity[0],
					color = base_color)
			else:
				surf = mlab.triangular_mesh(v[:,0], v[:,1], v[:,2], f,
					scalars = scalar_data, 
					vmin = opts.thresholds[0],
					vmax = opts.thresholds[1],
					name = os.path.basename(volume))
				surf.module_manager.scalar_lut_manager.lut.table = transparent_cmap_array
			surf.actor.mapper.interpolate_scalars_before_mapping = 1

	if opts.binarizedimportvolume:
		for volume in opts.binarizedimportvolume:
			invol = nib.load(volume)
			data = check_byteorder(invol.get_data())
			data[data != 0] = 1 # binary sandwich
			v, f, scalar_data = convert_voxel(data, affine = invol.get_affine())
			scalar_data[scalar_data != 0] = 1

			surf = mlab.triangular_mesh(v[:,0], v[:,1], v[:,2], f,
				scalars = scalar_data, 
				name = os.path.basename(volume),
				opacity = opts.opacity[0],
				color = base_color)
			surf.actor.mapper.interpolate_scalars_before_mapping = 1

	if opts.thresholdimportvolume:
		n_ivt = len(opts.thresholdimportvolume)
		if n_ivt not in (3, 4, 5):
			print("Error -ivt must have three to five inputs (volume, lower, upper, [lookup table], [opacity])")
			sys.exit()
		# optional 4th value: lookup table (default -lut); optional 5th: opacity (default -o)
		vol_lut = opts.thresholdimportvolume[3] if n_ivt >= 4 else opts.lookuptable[0]
		vol_cmap_array = get_cmap_array(vol_lut, 0)
		vol_opacity = float(opts.thresholdimportvolume[4]) if n_ivt == 5 else opts.opacity[0]

		invol = nib.load(str(opts.thresholdimportvolume[0]))
		data = check_byteorder(invol.get_data())
		v, f, scalar_data = convert_voxel(data, affine = invol.get_affine())
		v, f, scalar_data = _smooth_mesh(v, f, opts, scalar_data)

		surf = mlab.triangular_mesh(v[:,0], v[:,1], v[:,2], f,
			scalars = scalar_data, 
			vmin = float(opts.thresholdimportvolume[1]),
			vmax = float(opts.thresholdimportvolume[2]),
			opacity = vol_opacity,
			name = os.path.basename(str(opts.thresholdimportvolume[0])))
		surf.module_manager.scalar_lut_manager.lut.table = vol_cmap_array
		surf.actor.mapper.interpolate_scalars_before_mapping = 1

	if opts.importfreesurfer:
		# validate before the loop so a short -fsmgh list cannot raise IndexError first
		if opts.importfreesurfermgh and (len(opts.importfreesurfer) != len(opts.importfreesurfermgh)):
			print("The number of -fsmgh inputs must match the -ifs inputs.")
			sys.exit()
		for k, insurf in enumerate(opts.importfreesurfer):
			v, f = convert_fs(insurf)
			if opts.importfreesurfermgh:
				# -fsmgh files pair with -ifs surfaces by position
				mghvol = opts.importfreesurfermgh[k]
				invol = nib.load(mghvol).get_data()
				if invol.ndim != 3: 
					print("Scalar (MGH) files with %d dimensions are not supported (hint: there should only be one subject/contrast)." % invol.ndim)
					sys.exit()
				scalar_data = check_byteorder(np.squeeze(invol)) # annoying endianess issue fix
				v, f, scalar_data = _smooth_mesh(v, f, opts, scalar_data)

				surf = mlab.triangular_mesh(v[:,0], v[:,1], v[:,2], f,
					scalars = scalar_data, 
					vmin = opts.thresholds[0],
					vmax = opts.thresholds[1],
					name = mghvol)
				surf.module_manager.scalar_lut_manager.lut.table = cmap_array
			else:
				v, f, _ = _smooth_mesh(v, f, opts)
				surf = mlab.triangular_mesh(v[:,0], v[:,1], v[:,2], f,
					name = insurf,
					opacity = opts.opacity[0],
					color = base_color)
			surf.actor.mapper.interpolate_scalars_before_mapping = 1

	if surf is None:
		print("Error: nothing to display.")
		sys.exit()

	if opts.savesnapshots:
		_save_snapshots(surf, opts.savesnapshots[0], opts.savetype[0])
		rl_cmap = colors.ListedColormap(cmap_array[:,0:3]/255)
		write_colorbar(opts.thresholds, rl_cmap, opts.lookuptable[0])

	surf.scene.background = (0,0,0)
	mlab.show()

if __name__ == "__main__":
	# Script entry point: build the CLI, parse sys.argv and hand the options to run().
	run(getArgumentParser().parse_args())

