#!/usr/bin/env python2
# -*- coding: utf-8 -*-

import sys
import logging
import os
import os.path as op

from optparse import OptionParser
from pprint import pformat

import pyhrf
from pyhrf.ndarray import xndarray

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

if 'DISPLAY' not in os.environ:
    # No DISPLAY variable in the environment: treat the session as
    # non-interactive and select matplotlib's non-graphical 'Agg' backend.
    # This is done before matplotlib.pyplot gets imported further down.
    import matplotlib
    matplotlib.use('Agg')

# Usage line shown by --help. optparse substitutes '%prog' with the program
# name through a plain string replacement (the string is never %-formatted),
# so the percent sign must not be doubled: '%%prog' would be displayed as
# '%<program name>'.
usage = 'usage: %prog [options] VOL_FILE'
description = 'Plot slice(s) of a given volume '

parser = OptionParser(usage=usage, description=description)

# Bounds on the number of positional arguments: exactly one (VOL_FILE).
minArgs = 1
maxArgs = 1


# --- Output naming and location ---------------------------------------------
# The help text mirrors the file name pattern actually used when saving:
# '%s_%s_%d%s.png' % (tag, orientation, slice index, complementary suffix).
parser.add_option('-o', '--output', dest='output_tag', default=None,
                  help='Tag to construct the output file name as: '
                  '<tag>_<orientation>_<slice index>.png (with '
                  '_<axis_name>_<axis_value> inserted before the extension '
                  'for each -i option). Default tag is the input file name '
                  'without directory and last extension')

parser.add_option('-d', '--output_dir', dest='output_dir', default='./',
                  help='Output directory where to store outputs. '
                  'Default is %default')

# --- Slice selection: at least one of -a / -s / -c is required --------------
parser.add_option('-a', '--ax', dest='axial_slice', default=None, type='int',
                  help='Axial slice index (in functional referential, '
                  'as provided by pyhrf_view)')

parser.add_option('-s', '--sag', dest='sagittal_slice', default=None, type='int',
                  help='Sagittal slice index (in functional referential, '
                  'as provided by pyhrf_view)')

parser.add_option('-c', '--cor', dest='coronal_slice', default=None, type='int',
                  help='Coronal slice index (in functional referential, '
                  'as provided by pyhrf_view)')

# --- Optional overlays -------------------------------------------------------
parser.add_option('-m', '--roi-mask', dest='mask_file', metavar='NIFTI_FILE',
                  default=None,
                  help="Mask file of ROIs to be contoured on the plots")

parser.add_option('-y', '--anatomy', dest='anatomy_file', metavar='NIFTI_FILE',
                  default=None,
                  help="Anatomy file on which to superimpose the functional "
                  "data.")

# --- Slicing along the non-spatial axes --------------------------------------
parser.add_option('-t', '--time-slice', dest='time_slice', type='int',
                  default=0,
                  help="Index of time slice to take for 4D vol. "
                  "Default is %default")

# Repeatable option: every occurrence is appended to options.other_slices,
# which stays None when the option is never given.
parser.add_option('-i', '--other-slice-index', dest='other_slices',
                  default=None, action='append',
                  help="Complementary indexing to define the slice in case of "
                  "ndims > 3. Format is <axis_name>:<axis_value>.")

# The help of the verbosity option is the pretty-printed pyhrf.verbose_levels.
parser.add_option('-v', '--verbose', dest='verbose', metavar='VERBOSELEVEL',
                  type='int', default=0,
                  help=pformat(pyhrf.verbose_levels))


options, args = parser.parse_args()

# Forward the requested verbosity level to pyhrf's logger.
pyhrf.logger.setLevel(options.verbose)

# Show the help and exit with status 1 unless the number of positional
# arguments lies within [minArgs, maxArgs]; a negative maxArgs disables the
# upper bound.
nba = len(args)
too_few = nba < minArgs
too_many = maxArgs >= 0 and nba > maxArgs
if too_few or too_many:
    parser.print_help()
    sys.exit(1)

# Spatial slices requested on the command line, as {orientation: index}.
# Orientations whose option was not given are left out.
slice_orientations = {}
for orientation, index in (('axial', options.axial_slice),
                           ('sagittal', options.sagittal_slice),
                           ('coronal', options.coronal_slice)):
    if index is not None:
        slice_orientations[orientation] = index

# Complementary slicing built from the repeated -i options, as
# {axis_name: axis_value}. options.other_slices is None when -i is not used.
other_slice_orientations = {}
for slice_index in options.other_slices or ():
    # Split on the first colon only. A specification without a colon or with
    # an empty axis name is reported as a usage error (message on stderr,
    # exit status 2) instead of failing with an unpacking traceback.
    axis_name, sep, value = slice_index.partition(':')
    if not sep or not axis_name:
        parser.error('invalid slice specification "%s", expected format is '
                     '<axis_name>:<axis_value>' % slice_index)
    # Values made only of digits are converted to int, anything else is kept
    # as a string.
    if value.isdigit():
        value = int(value)
        # TODO: treat floats also
    other_slice_orientations[axis_name] = value

if not slice_orientations:
    # Parenthesised single-argument prints produce the same output under
    # Python 2 and Python 3.
    print('')
    print('Error: at least one slice orientation is required')
    print('')
    parser.print_help()
    sys.exit(1)


# The single positional argument is the volume to plot.
vol_fn = args[0]
cfunc = xndarray.load(vol_fn)

# Optional overlays: each one stays None when its option was not given.
cparcellation = (xndarray.load(options.mask_file)
                 if options.mask_file is not None else None)
canat = (xndarray.load(options.anatomy_file)
         if options.anatomy_file is not None else None)

import matplotlib.pyplot as plt

from pyhrf.ndarray import MRI3Daxes
from pyhrf.plot import plot_func_slice, autocrop, fig_orientations
# Prefix of the output file names: the -o tag when given, otherwise the base
# name of the input volume stripped of its last extension.
tag = options.output_tag
if tag is None:
    tag = op.splitext(op.basename(vol_fn))[0]

# File name suffix encoding the complementary slicing (-i options). It does
# not depend on the slice orientation, so it is built once, before the loop.
if other_slice_orientations:
    sother_ori = "_" + "_".join('%s_%s' % (str(a), str(v))
                                for a, v in other_slice_orientations.items())
else:
    sother_ori = ''

# Create the output directory when it is missing, so that saving the figures
# does not fail on a non-existing path.
if options.output_dir and not op.isdir(options.output_dir):
    os.makedirs(options.output_dir)

# One figure per requested orientation. items() is used rather than
# iteritems() so that the loop is valid under both Python 2 and Python 3.
for axis, islice in slice_orientations.items():
    # Extract the requested slice, also applying the complementary slicing.
    slice_spec = {axis: islice}
    slice_spec.update(other_slice_orientations)
    func_slice = cfunc.sub_cuboid(**slice_spec)
    if func_slice.has_axis('time'):
        func_slice = func_slice.sub_cuboid(time=options.time_slice)

    # Orientation looked up for this kind of slice; the functional, mask and
    # anatomy slices are all put in this same orientation.
    ori = fig_orientations[axis]
    func_slice.set_orientation(ori)
    fdata = func_slice.data

    if cparcellation is not None:
        pdata = cparcellation.sub_cuboid(**{axis: islice}).reorient(ori).data
    else:
        pdata = None

    if canat is not None:
        # Rescale the functional slice index by the ratio of voxel sizes along
        # the sliced axis to get the matching slice index in the anatomy.
        scale_factor = func_slice.get_voxel_size(axis) /   \
            canat.get_voxel_size(axis)
        aislice = int(round(islice * scale_factor))
        anat_data = canat.sub_cuboid(**{axis: aislice}).reorient(ori).data
    else:
        anat_data = None

    plot_func_slice(fdata, anatomy=anat_data, parcellation=pdata)

    output_fig_fn = op.join(options.output_dir,
                            '%s_%s_%d%s.png' % (tag, axis, islice, sother_ori))
    logger.info('Save to: %s', output_fig_fn)
    plt.savefig(output_fig_fn)
    # plt.savefig wrote pyplot's current figure; close that same figure so
    # that figures do not accumulate and the next orientation starts from a
    # fresh one instead of the one just saved.
    plt.close()
    autocrop(output_fig_fn)
