#!/usr/bin/env python2
# -*- coding: utf-8 -*-

import sys
import logging
import os.path as op

from optparse import OptionParser
from pprint import pformat

import numpy as np
import pyhrf
from pyhrf.ndarray import xndarray, MRI3Daxes, MRI4Daxes
from pyhrf.plot import autocrop, fig_orientations, plot_cub_as_curve
import matplotlib.pyplot as plt

logger = logging.getLogger(__name__)


def plot_mips(cub, show_colbar=False, group_views=False, roi_mask=None,
              hrf=None, plot_hcano=False):
    """Plot the maximum intensity projections (MIPs) of a 3D volume.

    One projection is computed for each axis listed in ``MRI3Daxes`` by
    taking the maximum of `cub` along that axis. On every projection the
    position of its maximum is marked with a cross and, when `roi_mask` is
    given, the projected mask is outlined in blue.

    The colour map is read from the module-level command-line
    ``options.colormap``, so ``options`` must exist when this is called.

    Args:
        cub: volume to project. It must provide ``get_voxel_size``,
            ``get_dshape``, ``min`` and ``max`` (the script passes an
            ``xndarray`` oriented along ``MRI3Daxes``).
        show_colbar (bool): if True, add a colour bar.
        group_views (bool): if False, create one figure per axis; if True,
            lay the three projections out in a single figure whose
            geometry is derived from the voxel sizes and the volume shape.
        roi_mask: optional mask volume (same interface as `cub`) whose
            projection is contoured on each view.
        hrf: optional cuboid plotted with ``plot_cub_as_curve`` in the
            bottom-right corner of the grouped figure. Ignored when
            `group_views` is False.
        plot_hcano: currently unused.

    Returns:
        dict: maps a figure tag to the matplotlib figure. Tags are
        ``'mip_<axis>'`` in separate mode and ``'mips'`` in grouped mode.
    """
    res = dict((a, cub.get_voxel_size(a)) for a in MRI3Daxes)

    cm = plt.get_cmap(options.colormap)
    # One normalisation for all the views so that they share a colour scale.
    norm = plt.Normalize(vmin=cub.min(), vmax=cub.max())

    figs = {}
    if not group_views:
        for axis in MRI3Daxes:
            orientation = fig_orientations[axis]
            mip = cub.max(axis)
            mip.set_orientation(orientation)
            data = np.flipud(mip.data)
            figs['mip_%s' % axis] = plt.figure()
            img = plt.imshow(data, interpolation='nearest', cmap=cm,
                             norm=norm)
            imax = np.unravel_index(data.argmax(), data.shape)
            logger.debug('imax (%s): %s', axis, imax)
            plt.plot(imax[1], imax[0], 'x')

            if roi_mask is not None:
                mip_roi = roi_mask.max(axis)
                mip_roi.set_orientation(orientation)
                plt.contour(np.flipud(mip_roi.data).astype(int), 1,
                            colors=['blue'], linewidths=3., alpha=.6)

            # plt.imshow returns the image artist; its parent axes is
            # available through the ``axes`` attribute.
            mpl_axes = img.axes
            mpl_axes.set_axis_off()
            mpl_axes.set_aspect(res[orientation[0]] / res[orientation[1]])

            if show_colbar:
                # The colour bar is created from the image (mappable).
                plt.colorbar(img)
        return figs

    dshape = cub.get_dshape()
    logger.debug('res: %s', res)
    logger.debug('shape: %s', dshape)

    # Extent of the volume along each axis (voxel size * number of voxels).
    brain_p = res['coronal'] * dshape['coronal']
    brain_h = res['axial'] * dshape['axial']
    brain_w = res['sagittal'] * dshape['sagittal']
    if roi_mask is None:
        gap_w = -(brain_w + brain_p) * .02
        gap_h = -(brain_h + brain_p) * .03
    else:
        # When the ROI contour is plotted the placement of the axes has to
        # change (reason unknown, found empirically).
        gap_w = (brain_w + brain_p) * .02
        gap_h = (brain_h + brain_p) * .03

    gap_bot = 10.
    gap_top = 0.
    gap_left = 0.
    gap_right = 0.
    fig_h = brain_p + gap_h + brain_h + gap_top + gap_bot
    fig_w = brain_w + gap_w + brain_p + gap_left + gap_right

    logger.debug('fig_h: %s, fig_w: %s', fig_h, fig_w)
    logger.debug('brain_p: %s, brain_h: %s, brain_w: %s',
                 brain_p, brain_h, brain_w)

    fig = plt.figure(figsize=(fig_w / 40, fig_h / 40), dpi=150)
    figs['mips'] = fig

    def draw_view(axis, rect, extent, col_axis, row_axis, **shared):
        """Draw the MIP along `axis` in new axes placed at `rect`.

        `col_axis` and `row_axis` name the volume axes mapped to the image
        columns and rows; their voxel sizes convert the index of the
        maximum into `extent` coordinates. `shared` is forwarded to
        ``plt.axes`` (sharex / sharey). Returns (axes, image).
        """
        mip = cub.max(axis)
        mip.set_orientation(fig_orientations[axis])
        mpl_axes = plt.axes(rect, **shared)
        img = mpl_axes.imshow(mip.data, interpolation='nearest',
                              extent=extent, origin='lower', cmap=cm,
                              norm=norm)
        imax = np.unravel_index(mip.data.argmax(), mip.data.shape)
        logger.debug('imax (%s): %s', axis, imax)
        mpl_axes.plot(imax[1] * res[col_axis], imax[0] * res[row_axis],
                      'x', color='red')
        if roi_mask is not None:
            mip_roi = roi_mask.max(axis)
            mip_roi.set_orientation(fig_orientations[axis])
            mpl_axes.contour(mip_roi.data.astype(int), 1, colors=['blue'],
                             linewidths=3., alpha=.6, extent=extent)
        mpl_axes.set_axis_off()
        return mpl_axes, img

    # Positions in figure-relative coordinates.
    left_col = gap_left / fig_w
    right_col = (brain_w + gap_w + gap_left) / fig_w
    top_row = (brain_p + gap_h + gap_bot) / fig_h
    bottom_row = gap_bot / fig_h

    # Coronal view: top-left.
    ax_cor, _ = draw_view('coronal',
                          [left_col, top_row,
                           brain_w / fig_w, brain_h / fig_h],
                          [0, brain_w, 0, brain_h],
                          'sagittal', 'axial')

    # Sagittal view: top-right, shares its vertical axis with the coronal
    # view.
    draw_view('sagittal',
              [right_col, top_row, brain_p / fig_w, brain_h / fig_h],
              [0, brain_p, 0, brain_h],
              'coronal', 'axial', sharey=ax_cor)

    # Axial view: bottom-left, shares its horizontal axis with the coronal
    # view.
    ax_ax, img_ax = draw_view('axial',
                              [left_col, bottom_row,
                               brain_w / fig_w, brain_p / fig_h],
                              [0, brain_w, 0, brain_p],
                              'sagittal', 'coronal', sharex=ax_cor)
    logger.debug('ax_ax: %s', ax_ax.get_position())

    cbar_width = 0.
    if show_colbar:
        # Colour bar to the right of the axial view.
        col_bar_width = (brain_w / 15.) / fig_w
        ax_col_bar = plt.axes([right_col, bottom_row, col_bar_width,
                               brain_p * 1. / fig_h])
        plt.colorbar(img_ax, cax=ax_col_bar)
        cbar_width = .08

    if hrf is not None:
        # Curves go in the bottom-right corner, shifted to leave room for
        # the colour bar when there is one.
        ax_hrf = plt.axes([right_col + cbar_width + .05,
                           bottom_row,
                           brain_p / fig_w - cbar_width - .075,
                           brain_p / fig_h])
        plot_cub_as_curve(hrf, axes=ax_hrf, show_legend=True)

    return figs

# Command-line interface definition.
# optparse replaces "%prog" with the name of the running script. The string
# is not %-formatted, so the percent sign must not be doubled.
usage = 'usage: %prog [options] VOL_FILE'
description = 'Plot Maximum Intensity Maps of a given volume '

parser = OptionParser(usage=usage, description=description)

# Exactly one positional argument is expected: the volume file.
minArgs = 1
maxArgs = 1


parser.add_option('-t', '--output-tag', dest='output_tag', default=None,
                  help='Output tag default is '
                  '<input_file>_mip_<orientation>.png for multiple images '
                  'or <input_file>_mips.png if all MIPs are grouped within '
                  'one image with the option "group-views".')

parser.add_option('-o', '--output-dir', dest='output_dir', default='./',
                  metavar='PATH',
                  help='Output directory. Default is %default.')

parser.add_option('-p', '--colormap', dest='colormap', default='Greys_r',
                  metavar='Matplotlib colormap name',
                  help='Color map for image plots. Default is %default.')

parser.add_option('-m', '--mask', dest='mask_file', metavar='MASK_FILE',
                  default=None,
                  help="Mask file defining the contour of the brain")

parser.add_option('-r', '--roi-mask', dest='roi_mask_file', metavar='MASK_FILE',
                  default=None,
                  help="Mask file of the ROI to be contoured on the plots")


parser.add_option('-v', '--verbose', dest='verbose', metavar='VERBOSELEVEL',
                  type='int', default=0,
                  help=pformat(pyhrf.verbose_levels))

parser.add_option('-c', '--colorbar',
                  default=False, action='store_true',
                  help='Plot colorbar')


# "--goup-views" is the historical, misspelt name of this option. It is kept
# as an alias so that existing command lines keep working.
parser.add_option('-g', '--group-views', '--goup-views', dest='group',
                  default=False, action='store_true',
                  help='Group all views in a single image')

(options, args) = parser.parse_args()

# Apply the requested verbosity to the pyhrf package logger.
pyhrf.logger.setLevel(options.verbose)

# Check the number of positional arguments; show the help and exit with an
# error status when it is out of bounds.
nba = len(args)
if nba < minArgs or (maxArgs >= 0 and nba > maxArgs):
    parser.print_help()
    sys.exit(1)

vol_fn = args[0]

# Load the volume and put its axes in the conventional MRI order.
cv = xndarray.load(vol_fn)
if cv.get_ndims() > 3:
    cv.set_orientation(MRI4Daxes)
else:
    cv.set_orientation(MRI3Daxes)

# Optional brain mask, binarised: every strictly positive value is inside.
cmask = None
if options.mask_file is not None:
    cmask = xndarray.load(options.mask_file).squeeze()
    cmask.data = cmask.data > 0
    cmask.set_orientation(MRI3Daxes)

# Optional ROI mask, contoured on top of the projections.
cmask_roi = None
if options.roi_mask_file is not None:
    cmask_roi = xndarray.load(options.roi_mask_file)
    cmask_roi.set_orientation(MRI3Daxes)

# The output tag defaults to the input file name without its directory and
# without its last extension.
tag = options.output_tag
if tag is None:
    tag = op.splitext(op.basename(vol_fn))[0]
logger.debug('Input volume: %s', vol_fn)
logger.debug('Output tag: %s', tag)


def save_figs(figs, ftag):
    """Save every figure of `figs` as a PNG file in the output directory.

    Each file is named ``<ftag>_<key>.png`` and is written to the
    module-level ``options.output_dir``. The saved file is then passed to
    ``autocrop`` (from pyhrf.plot; presumably it trims the image margins in
    place -- to be confirmed).

    Args:
        figs (dict): maps a figure tag to an object providing ``savefig``.
        ftag (str): prefix of the output file names.
    """
    # items() is available on both Python 2 and Python 3 dictionaries.
    for fig_key, fig in figs.items():
        output_fig_fn = op.join(options.output_dir,
                                '%s_%s.png' % (ftag, fig_key))
        logger.info('Save to: %s', output_fig_fn)
        fig.savefig(output_fig_fn)
        autocrop(output_fig_fn)

if cv.get_ndims() < 4:
    # 3D volume: plot its MIPs directly.

    if cmask is not None:
        # Mask out everything that lies outside the brain mask (a masked
        # array flags the entries to ignore, hence the negation).
        cv.data = np.ma.masked_array(cv.data,
                                     np.logical_not(
                                         cmask.data.astype(bool)),
                                     fill_value=-1000.)

    figs = plot_mips(cv, show_colbar=options.colorbar,
                     group_views=options.group, roi_mask=cmask_roi)
    save_figs(figs, tag)

else:
    # 4D volume: plot the MIPs of the temporal mean and of the temporal
    # variance, each saved under its own tag.
    figs = plot_mips(cv.mean('time'), show_colbar=options.colorbar,
                     group_views=options.group)
    save_figs(figs, tag + '_mean_time')

    figs = plot_mips(cv.var('time'), show_colbar=options.colorbar,
                     group_views=options.group)
    save_figs(figs, tag + '_var_time')
