#!/usr/bin/env python2
# -*- coding: utf-8 -*-

import sys
import logging
import os.path as op

from optparse import OptionParser
from pprint import pformat

import numpy as np

import pyhrf
from pyhrf.tools._io.spmio import loadmat
from pyhrf.tools import add_suffix
from pyhrf.ndarray import xndarray, MRI4Daxes, stack_cuboids

logger = logging.getLogger(__name__)

# optparse substitutes the literal token '%prog' with the program name using a
# plain string replacement (no %-formatting is involved), so the percent sign
# must not be doubled -- otherwise a stray '%' shows up in the usage line.
usage = 'usage: %prog [options] SPM.MAT '

# Adjacent string literals are concatenated as-is: every fragment carries its
# own trailing space so that sentences do not run together in --help output.
description = 'Generate the stimulus induced signal encoded in the design ' \
              'matrix. Output shape is the same as functional data recorded ' \
              'in the input SPM.mat'
parser = OptionParser(usage=usage, description=description)

# Destination of the generated volume; None means "pick a default name".
parser.add_option('-o', '--output', dest='output_file',
                  metavar='NII_FILE', default=None,
                  help='The output file')

# Disabled option, kept for reference:
# parser.add_option('-t','--target-volume',dest='target_volume',
#                   metavar='NII_FILE', default=None,
#                   help='A nifti file whose 3D geometry is used '\
#                        'for the output signal')

# Relative weights applied to the columns of the design matrix.
parser.add_option('-s', '--effect-sizes', dest='effect_sizes',
                  metavar='LIST OF FLOAT', default=None,
                  help='Comma separated list (no space) of effect sizes '
                       'to generate the stim induced signal. The length of '
                       'the list must be the same as the number of regressors '
                       'in the design matrix. Note: effects are only relative')

# Flag (no value expected): when set, the fMRI data is stacked with the result.
parser.add_option('-f', '--stack-with-fmri', dest='stack_with_fmri',
                  metavar='BOOLEAN', default=False, action='store_true',
                  help='Also stack the input fMRI signal in a 5th dim.')

# The help text lists the verbosity levels known to pyhrf.
parser.add_option('-v', '--verbose', dest='verbose', metavar='VERBOSELEVEL',
                  type='int', default=0,
                  help=pformat(pyhrf.verbose_levels))

(options, args) = parser.parse_args()
pyhrf.verbose.set_verbosity(options.verbose)

# Exactly one positional argument is expected (the SPM.mat path); with any
# other count the help text is printed and the script exits with status 1.
if len(args) != 1:
    parser.print_help()
    sys.exit(1)

(spm_mat_fn,) = args
# Directory part of the SPM.mat path, './' when the path has none.
input_dir = op.dirname(spm_mat_fn) or './'

spm = loadmat(spm_mat_fn)['SPM']

# Design matrix, one column per regressor. This nesting works for SPM8.
# TODO: make a func, test for SPM5 & SPM12
dmat = spm['xX'][0][0]['X'][0][0]
nb_regressors = dmat.shape[1]

regnames = [a[0] for a in spm['xX'][0][0]['name'][0][0][0]]

if options.effect_sizes is not None:
    effect_sizes = [float(s) for s in options.effect_sizes.split(',')]
    if len(effect_sizes) != nb_regressors:
        # ValueError is a subclass of Exception, so existing handlers still
        # catch it.
        raise ValueError(
            'length of effect size ({nbe}) is not the same as the '
            'number of regressors in the design matrix ({nbr}). '
            'Names of available regressors:\n{regnames}'.format(
                nbe=len(effect_sizes), nbr=nb_regressors,
                regnames='\n'.join(regnames)))
else:
    # Default: every regressor gets the same unit weight. Sized from the
    # design matrix itself so that it always matches the dot product done
    # with `dmat` later on, and the check applied to user-supplied values.
    effect_sizes = np.ones(nb_regressors)

logger.info('Regressor names and associated effect sizes:')
logger.info('\n'.join(['%s : %1.2f' % (n, e)
                       for n, e in zip(regnames, effect_sizes)]))

# File name of the first functional scan referenced in the SPM structure; its
# volume provides the geometry (and later the meta data) of the output.
func_fn = spm[0][0]['xY']['VY'][0][0]['fname'][0][0][0]
target_vol = xndarray.load(func_fn)

# Weighted sum of the design matrix columns.
weighted_design = np.dot(dmat, effect_sizes)
# Volume of ones with the target geometry, extended with a trailing axis so
# that the weighted design is broadcast identically onto every voxel.
unit_volume = np.ones_like(target_vol.data)[:, :, :, np.newaxis]
stim_ind = unit_volume * weighted_design

logger.info('Target geometry: %s', str(stim_ind.shape))

# Resolve the output file name. Without -o/--output, write next to the input
# SPM.mat under a default name (suffixed when the fMRI data is stacked too).
if options.output_file is None:
    output_fn = op.join(op.dirname(spm_mat_fn), './stim_induced_from_SPM.nii')
    if options.stack_with_fmri:
        output_fn = add_suffix(output_fn, '_with_fmri')
else:
    # Honour the user-supplied path as given; without this branch `output_fn`
    # was left undefined and saving failed with a NameError.
    output_fn = options.output_file
# Time axis of the output: scan indexes scaled by the repetition time read
# from the SPM structure. When reading it (or the scaling) raises ValueError,
# plain scan indexes are used instead.
scan_indexes = np.arange(stim_ind.shape[3])
try:
    tr = spm[0][0]['xY']['RT'][0][0][0][0]
    taxis = scan_indexes * tr
except ValueError:
    taxis = scan_indexes


# Wrap the generated signal as a 4D cuboid reusing the meta data of the
# functional volume loaded above.
xstim_ind = xndarray(stim_ind, axes_names=MRI4Daxes,
                     axes_domains={'time': taxis},
                     meta_data=target_vol.meta_data)
if options.stack_with_fmri:
    # load all scans:
    scan_fns = [f[0][0] for f in spm[0][0]['xY']['VY'][0][0]['fname']]
    scans = [xndarray.load(fn) for fn in scan_fns]
    fmri = stack_cuboids(scans, 'time', taxis).reorient(MRI4Daxes)

    # rescale stim induced to fmri signal:
    # 1) remove the temporal mean of the stim induced signal,
    stim_mean = xstim_ind.mean('time').data[:, :, :, np.newaxis]
    xstim_ind.data -= stim_mean
    # 2) scale it to 0.9 times the ratio of the fMRI ptp('time') over its
    #    own ptp('time'),
    fmri_range = fmri.ptp('time').data[:, :, :, np.newaxis]
    stim_range = xstim_ind.ptp('time').data[:, :, :, np.newaxis]
    xstim_ind.data *= fmri_range / stim_range * .9
    # 3) add the temporal mean of the fMRI signal.
    fmri_mean = fmri.mean('time').data[:, :, :, np.newaxis]
    xstim_ind.data += fmri_mean

    logger.info('Save to: %s', output_fn)
    # Both signals are stacked along a new 'type' axis.
    stacked = stack_cuboids([fmri, xstim_ind], 'type', ['fmri', 'design'])
    stacked.save(output_fn)
else:
    logger.info('Save to: %s', output_fn)
    xstim_ind.save(output_fn)
