#!/usr/bin/env python

#    view *.tmi images for TFCE_mediation
#    Copyright (C) 2016  Tristram Lett

#    This program is free software: you can redistribute it and/or modify
#    it under the terms of the GNU General Public License as published by
#    the Free Software Foundation, either version 3 of the License, or
#    (at your option) any later version.

#    This program is distributed in the hope that it will be useful,
#    but WITHOUT ANY WARRANTY; without even the implied warranty of
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#    GNU General Public License for more details.

#    You should have received a copy of the GNU General Public License
#    along with this program.  If not, see <http://www.gnu.org/licenses/>.

from __future__ import division
import os
import sys
import numpy as np
import warnings
import argparse as ap
import nibabel as nib
import matplotlib.pyplot as plt
import matplotlib.colors as colors

from tfce_mediation.tm_io import read_tm_filetype
from tfce_mediation.tm_func import print_tmi_history
from tfce_mediation.pyfunc import convert_voxel, convert_fs, write_colorbar, create_adjac_vertex, vectorized_surface_smooth

# hopefully safer loading of mayavi: use the PyQt backend unless QT_API was
# already set by the user, and retry with PySide if importing mayavi fails.
if 'QT_API' not in os.environ:
	os.environ['QT_API'] = 'pyqt'
try:
	from mayavi import mlab
except Exception:
	# Deliberately broader than ImportError (a Qt binding problem may surface
	# as another exception type during import), but unlike a bare except it
	# lets KeyboardInterrupt/SystemExit propagate.
	print("Trying pyside")
	os.environ['QT_API'] = 'pyside'
	from mayavi import mlab

# Help text for the command-line parser built by getArgumentParser(); it also
# carries the Mayavi citation requested from users of this viewer.
DESCRIPTION = """
Display surfaces and volumes from a tmi file.

This script relies on Mayavi. If you use it please cite:

Ramachandran, P. and Varoquaux, G., `Mayavi: 3D Visualization of Scientific Data` IEEE Computing in Science & Engineering, 13 (2), pp. 40-51 (2011)
"""


def getArgumentParser(ap = None):
	"""Register the options of the tmi viewer on an argument parser.

	ap : argparse.ArgumentParser, optional
		Parser (or sub-parser) to populate. When omitted, a new parser using
		DESCRIPTION as its help text is created on every call. (The parameter
		name shadows the module-level ``argparse`` alias and is kept for
		backward compatibility with callers passing ``ap=``.)

	Returns the populated parser.
	"""
	if ap is None:
		# Built per call: a parser created as the default argument value would
		# be shared by every call, and a second call would register the same
		# option strings again (argparse raises on such conflicts).
		import argparse
		ap = argparse.ArgumentParser(description = DESCRIPTION, formatter_class = argparse.RawTextHelpFormatter)

	ap.add_argument("-i", "-i_tmi", "--inputtmi",
		help="Input the *.tmi file containing the statistics to view.",
		nargs=1, 
		metavar='*.tmi',
		required=True)
	ap.add_argument("-oh", "--history",
		help="Output tmi file history and exits.", 
		action='store_true')
	ap.add_argument("-d", "--display",
		help="""
			Select which object to display. The mask, surface, contrast must be entered as integers (check values 
			with -oh). Multiple objects can be displayed. The number of inputs must be a multiple of three.""",
		nargs = '+',
		type = int,
		metavar = 'int')
	# optional
	ap.add_argument("-dv", "--displayvoxelmask",
		help = "Display a volume as a surface. The mask and contrast must be entered as integers (check values with -oh). The number of inputs must be a multiple of two.", 
		nargs = '+',
		type = int,
		metavar = 'int')

	ap.add_argument("-lut", "--lookuptable",
		help = """
			Set the color map to display. The lookuptable can be red-yellow (r_y), blue-lightblue (b_lb) or any 
			matplotlib colorschemes (https://matplotlib.org/examples/color/colormaps_reference.html)""", 
		type = str,
		default = ['r_y'],
		nargs = 1)
	ap.add_argument("-t", "--thresholds",
		help = "Set lower and upper thresholds", 
		default=[.95,1],
		type = float,
		nargs = 2)
	ap.add_argument("-a", "--alpha",
		help = "Set alpha [0 to 255]", 
		default=[255],
		type = int,
		nargs = 1)
	ap.add_argument("-o", "--opacity",
		help = "Set opacity [0 to 1]", 
		default=[1.0],
		type = float,
		nargs = 1)

	# options for displaying surfaces/volumes without tmi statistics
	ap.add_argument("-ds", "--displaysurface",
		help = "Display a surface without a scalar (i.e., just the surface).", 
		nargs = '+',
		type = int,
		metavar = 'int')
	ap.add_argument("-iv", "--importvolume",
		help = "Import a volume (nifti or minc). e.g., -iv mean_FA_skeleton_mask.nii.gz", 
		nargs = '+',
		type = str,
		metavar = 'str')
	ap.add_argument("-biv", "--binarizeimportvolume",
		help = "Binarize the imported volume(s) from -iv (all non-zero values are set to 1).", 
		action = 'store_true')
	ap.add_argument("-ifs", "--importfreesurfer",
		help = """
			Import a freesurfer surface (use 'tm_multimodal create-tmi' to add other surface types). 
			e.g., -ifs $SUBJECT_DIR/fsaverage/surf/lh.midthickness""", 
		nargs = '+',
		type = str,
		metavar = 'str')
	ap.add_argument("-save", "--savesnapshots",
		help = "Save snapshots of the image. Input the basename of the output.", 
		nargs = 1,
		type = str,
		metavar = 'basename')
	ap.add_argument("--savetype",
		help = "Choose output snapshot type by file extension. Default is: %(default)s", 
		nargs = 1,
		type = str,
		default = ['png'],
		metavar = 'filetype')
	ap.add_argument("-ss","--surfacesmoothing",
		help = "Apply Laplacian or Taubin smoothing before visualization. Input the number of iterations (e.g., -ss 5)", 
		nargs = 1,
		type = int,
		metavar = 'int')
	ap.add_argument("-stype","--smoothingtype",
		help = "Set type of surface smoothing to use (%(choices)s). The default is laplacian", 
		nargs = 1,
		choices = ['laplacian','taubin'],
		default = ['laplacian'],
		metavar = 'str')

	return ap

# Colour (0-255 RGB) of the lowest look-up-table entry and of surfaces drawn
# without scalars.
_BONE_RGB = (227, 218, 201)

# Camera presets used for snapshots, in the order they are saved:
# (mayavi scene method name, orientation suffix in the output file name).
_SNAPSHOT_VIEWS = (
	('x_minus_view', 'left'),
	('x_plus_view', 'right'),
	('y_minus_view', 'posterior'),
	('y_plus_view', 'anterior'),
	('z_minus_view', 'inferior'),
	('z_plus_view', 'superior'),
	('isometric_view', 'isometric'),
)


def _make_lut(name, alpha):
	"""Return a 256 x 4 RGBA look-up table scaled to 0-255, or None if unknown.

	'r_y'/'red-yellow' and 'b_lb'/'blue-lightblue' are built here; any other
	name is looked up as a matplotlib colormap. Entry 0 is always replaced by
	the bone colour with the requested alpha.
	"""
	if name in ('r_y', 'red-yellow'):
		lut = np.array(( (np.ones(256)*255), np.linspace(0,255,256), np.zeros(256), np.ones(256)*255.0)).T
	elif name in ('b_lb', 'blue-lightblue'):
		lut = np.array(( np.zeros(256), np.linspace(0,255,256), (np.ones(256)*255), np.ones(256)*255.0)).T
	else:
		# get_cmap looks the name up instead of eval()-ing user input
		try:
			lut = plt.get_cmap(name)(np.arange(256)) * 255
		except (ValueError, KeyError):
			return None
	lut[0] = list(_BONE_RGB) + [alpha] # set lowest threshold to bone
	return lut


def _smooth(v, f, opts, scalar_data = None):
	"""Smooth the mesh (v, f) using the iterations/mode requested in opts.

	Returns (v, f) when scalar_data is None, otherwise (v, f, scalar_data),
	exactly as returned by vectorized_surface_smooth.
	"""
	adjacency = create_adjac_vertex(v, f)
	kwargs = dict(number_of_iter = int(opts.surfacesmoothing[0]),
		mode = str(opts.smoothingtype[0]))
	if scalar_data is not None:
		kwargs['scalar'] = scalar_data
	return vectorized_surface_smooth(v, f, adjacency, **kwargs)


def _save_snapshots(surf, basename, filetype):
	"""Save one image per camera preset in _SNAPSHOT_VIEWS on a black background."""
	surf.scene.background = (0,0,0)
	for view, orientation in _SNAPSHOT_VIEWS:
		getattr(surf.scene, view)()
		mlab.savefig('%s_%s.%s' % (basename, orientation, filetype), size = (1600,1400))


def run(opts):
	"""Display the objects selected in opts with mayavi and open the viewer.

	opts is the namespace produced by getArgumentParser(). When too few
	command-line arguments are given (or -oh is used) the tmi history is
	printed and the program exits. Invalid option combinations print an error
	and exit with status 1.
	"""
	# if not enough inputs, output history
	if len(sys.argv) <= 3:
		opts.history = True

	# load tmi
	_, image_array, masking_array, maskname_array, affine_array, vertex_array, face_array, surfname, _, tmi_history, columnids = read_tm_filetype(opts.inputtmi[0], verbose=False)
	if opts.history:
		num_con = image_array[0].shape[1]
		if num_con > 500: # safer
			num_con = None
		print_tmi_history(tmi_history, 
			maskname_array, 
			surfname, 
			num_con = num_con,
			contrast_names = columnids)
		sys.exit()

	# Rows position_array[k]:position_array[k+1] of image_array[0] hold the
	# data of mask k (one row per True element of that mask).
	position_array = [0]
	for mask in masking_array:
		position_array.append(position_array[-1] + np.count_nonzero(mask == True))

	cmap_array = _make_lut(str(opts.lookuptable[0]), opts.alpha[0])
	if cmap_array is None:
		print("Error: Lookup table '%s' is not recognized." % opts.lookuptable[0])
		print("The lookup table can be red-yellow (r_y), blue-lightblue (b_lb) or any matplotlib colorschemes (https://matplotlib.org/examples/color/colormaps_reference.html)")
		sys.exit(1)

	bone_color = tuple(c / 255.0 for c in _BONE_RGB)
	surf = None # last mesh drawn; its scene is used for snapshots and background

	if opts.display:
		if len(opts.display) % 3 != 0:
			print("Error: There must be three inputs per surface (mask, surface, contrast).")
			sys.exit(1)

		# display the surfaces: one (mask, surface, contrast) triple per object
		for i in range(0, len(opts.display), 3):
			c_mask, c_surf, c_contrast = opts.display[i:i + 3]

			start = position_array[c_mask]
			end = position_array[c_mask+1]

			mask = masking_array[c_mask]
			scalar_data = np.zeros((mask.shape[0]))
			scalar_data[mask[:,0,0]] = image_array[0][start:end,c_contrast]
			v = vertex_array[c_surf][:]
			f = face_array[c_surf][:]

			if opts.surfacesmoothing:
				v, f, scalar_data = _smooth(v, f, opts, scalar_data)

			surf = mlab.triangular_mesh(v[:,0], v[:,1], v[:,2], f,
				scalars = scalar_data,
				opacity = opts.opacity[0],
				vmin=opts.thresholds[0],
				vmax=opts.thresholds[1],
				name = maskname_array[c_mask])
			surf.module_manager.scalar_lut_manager.lut.table = cmap_array
			surf.actor.mapper.interpolate_scalars_before_mapping = 1

	if opts.displayvoxelmask:
		# Voxel surfaces use a copy of the LUT whose lowest entry is fully
		# transparent, so the shared table is left untouched.
		voxel_lut = cmap_array.copy()
		voxel_lut[0] = list(_BONE_RGB) + [0]

		if len(opts.displayvoxelmask) % 2 != 0:
			print("Error: There must be two inputs per voxel surface (mask, contrast).")
			sys.exit(1)

		for j in range(0, len(opts.displayvoxelmask), 2):
			c_mask, c_contrast = opts.displayvoxelmask[j:j + 2]

			start = position_array[c_mask]
			end = position_array[c_mask+1]

			mask = masking_array[c_mask]
			scalar_data = np.zeros((mask.shape[0],mask.shape[1],mask.shape[2]))
			scalar_data[mask] = image_array[0][start:end,c_contrast]
			v, f, scalar_data = convert_voxel(scalar_data, affine = affine_array[c_mask])

			if opts.surfacesmoothing:
				v, f, scalar_data = _smooth(v, f, opts, scalar_data)

			surf = mlab.triangular_mesh(v[:,0], v[:,1], v[:,2], f,
				scalars = scalar_data, 
				vmin=opts.thresholds[0],
				vmax=opts.thresholds[1],
				name = maskname_array[c_mask])
			surf.module_manager.scalar_lut_manager.lut.table = voxel_lut
			surf.actor.mapper.interpolate_scalars_before_mapping = 1

	if opts.displaysurface:
		for c_surf in opts.displaysurface:

			v = vertex_array[c_surf]
			f = face_array[c_surf]

			if opts.surfacesmoothing:
				v, f = _smooth(v, f, opts)

			surf = mlab.triangular_mesh(v[:,0], v[:,1], v[:,2], f,
				opacity = opts.opacity[0],
				name = surfname[c_surf],
				color = bone_color)
			surf.actor.mapper.interpolate_scalars_before_mapping = 1

	if opts.importvolume:
		for volume in opts.importvolume:
			invol = nib.load(volume)
			# dataobj/affine replace get_data()/get_affine(), which are
			# deprecated (and removed in recent nibabel releases); np.array
			# makes a private copy so binarizing cannot alter cached image data.
			data = np.array(invol.dataobj)

			if opts.binarizeimportvolume:
				data[data != 0] = 1

			v, f, scalar_data = convert_voxel(data, affine = invol.affine)

			if opts.surfacesmoothing:
				v, f, scalar_data = _smooth(v, f, opts, scalar_data)

			# smoothing interpolates the scalars, so binarize again afterwards
			if opts.binarizeimportvolume:
				scalar_data[scalar_data != 0] = 1

			surf = mlab.triangular_mesh(v[:,0], v[:,1], v[:,2], f,
				scalars = scalar_data, 
				name = volume,
				opacity = opts.opacity[0],
				color = bone_color)
			surf.actor.mapper.interpolate_scalars_before_mapping = 1

	if opts.importfreesurfer:
		for insurf in opts.importfreesurfer:
			v, f = convert_fs(insurf)

			if opts.surfacesmoothing:
				v, f = _smooth(v, f, opts)

			surf = mlab.triangular_mesh(v[:,0], v[:,1], v[:,2], f,
				name = insurf,
				opacity = opts.opacity[0],
				color = bone_color)
			surf.actor.mapper.interpolate_scalars_before_mapping = 1

	# Previously this fell through to a NameError on `surf`.
	if surf is None:
		print("Error: Nothing to display. Select at least one object (e.g., -d, -dv, -ds, -iv or -ifs).")
		sys.exit(1)

	if opts.savesnapshots:
		_save_snapshots(surf, opts.savesnapshots[0], opts.savetype[0])
		rl_cmap = colors.ListedColormap(cmap_array[:,0:3]/255)
		write_colorbar(opts.thresholds, rl_cmap, opts.lookuptable[0])

	surf.scene.background = (0,0,0)
	mlab.show()

if __name__ == "__main__":
	# Script entry point: build the CLI, parse sys.argv and render.
	run(getArgumentParser().parse_args())

