#!/usr/bin/env python2
# -*- coding: utf-8 -*-

import sys
import string
import logging
import os.path as op

from optparse import OptionParser
from pprint import pformat

import numpy as np
import pyhrf
from pyhrf.tools._io import read_volume, write_volume
from pyhrf.parcellation import parcellation_ward_spatial
from pyhrf.graph import (graph_from_lattice, kerMask3D_6n,
                         graph_from_mesh)
from pyhrf.ndarray import xndarray, stack_cuboids, MRI3Daxes

logger = logging.getLogger(__name__)

# optparse replaces the literal token "%prog" with the program name. This
# string is never %-formatted, so the percent sign must not be doubled.
usage = 'usage: %prog [options] SPM_FILES'
description = 'Build a functional parcellation from SPM GLM result files'

parser = OptionParser(usage=usage, description=description)

# Accepted number of positional arguments (SPM_FILES). A negative maximum
# disables the upper-bound check performed after parsing.
minArgs = 1
maxArgs = -1


# The default output name is only built after parsing, once the method, the
# number of parcels and the data type are known, so it is described here
# rather than expanded through %default (which would print "None").
parser.add_option('-o', '--output', dest='outFile', default=None,
                  help='Output file for the parcellation, default is '
                  './parcel_GLM_[method]_np[nb_parcels].[nii|gii]')

parser.add_option('-v', '--verbose', dest='verbose', metavar='VERBOSELEVEL',
                  type='int', default=0,
                  help=pformat(pyhrf.verbose_levels))

parser.add_option('-m', '--mask', dest='maskFile', default=None,
                  metavar='MASK_FILE',
                  help='Functional mask file. For volume data, it is built '
                  'from the non-null voxels of the first SPM file if it '
                  'does not exist.')

parser.add_option('-g', '--mesh-file', dest='mesh_file', default=None,
                  metavar='MESH_FILE',
                  help='Mesh file, mandatory if SPM_FILES are surface textures.')

parser.add_option('-k', '--spatial-weight',
                  dest='spatial_weight', default=10, type='float',
                  metavar='FLOAT',
                  help='Weight of spatial coordinates relative to '
                  'functional features (for methods gkm, '
                  'ward_and_gkm or ward). Default is %default.')

parser.add_option('-n', '--nb-parcels', dest='nparcels', default=None,
                  metavar='INT', type='int',
                  help='Number of parcels. By default, it is derived from '
                  'the size of one parcel. '
                  'Do not use this option at the same time as '
                  'the option "size-parcel"')

parser.add_option('-d', '--dry-run', dest='dry', action='store_true',
                  default=False, help='Do not compute anything, '
                  'just print inputs.')

# type='float' makes optparse reject non-numeric sizes at parse time instead
# of handing a raw string to the arithmetic done later in the script.
parser.add_option('-s', '--size-parcel', dest='sparcel',
                  default=None, type='float', metavar='FLOAT',
                  help='Size of one parcel. Default is 2700 mm3 for '
                  'volumic data (cf. Thyreau06) or 300 mm2 for surfacic '
                  'data. Do not use this option at the same time as '
                  'the option "nb-parcels"')


# Available parcellation methods; the second one is the default.
choices = ['gkm', 'ward_and_gkm', 'ward', 'ward_spatial']
parser.add_option('-t', '--method', type='choice', choices=choices,
                  dest='method', metavar='STRING', default=choices[1],
                  help='Parcellation method. Choices are:\n %s, '
                  % ', '.join(choices) + 'default is %default')


# Parse the command line: `options` holds the flags declared above, `args`
# the remaining positional arguments (the SPM files).
options, args = parser.parse_args()

# The file-based parcellation entry point may fail to import; in that case
# the command gives up immediately with an explanatory message.
try:
    from pyhrf.parcellation import make_parcellation_from_files
except ImportError:
    print('Nipy is not installed, this command is not available.')
    sys.exit(1)

# Forward the requested verbosity to the pyhrf logger.
pyhrf.logger.setLevel(options.verbose)

# Check the number of positional arguments; the upper bound only applies
# when maxArgs is non-negative.
nba = len(args)
too_few = nba < minArgs
too_many = maxArgs >= 0 and nba > maxArgs
if too_few or too_many:
    parser.print_help()
    sys.exit(1)

from pyhrf.tools._io import has_ext_gzsafe, read_mesh


def is_surface_data_file(filename):
    """Tell whether `filename` is a surface data file.

    The decision is delegated to has_ext_gzsafe with the 'gii' extension
    (presumably tolerant of a trailing '.gz' -- confirm against
    pyhrf.tools._io).
    """
    surface_ext = 'gii'
    return has_ext_gzsafe(filename, surface_ext)


def is_volume_data_file(filename):
    """Tell whether `filename` is a volume data file.

    A file qualifies when has_ext_gzsafe accepts it for the 'nii' extension
    or, failing that, for the 'img' extension.
    """
    as_nii = has_ext_gzsafe(filename, 'nii')
    if as_nii:
        return as_nii
    return has_ext_gzsafe(filename, 'img')


# Detect the kind of input from the first SPM file. This fixes the output
# extension and the default size of one parcel for that kind of data.
if is_surface_data_file(args[0]):
    data_type = 'surface'
    data_extension = 'gii'
    default_sparcel = 300  # mm2
    if options.mesh_file is None:
        raise Exception('Mesh file not provided, use [-g|--mesh-file] option.')

elif is_volume_data_file(args[0]):
    data_type = 'volume'
    data_extension = 'nii'
    default_sparcel = 2700  # mm3
else:
    raise Exception('Unsupported data type %s' % args[0])

# Only fall back to the default when the user did not give -s/--size-parcel;
# overwriting it unconditionally would silently discard the user's choice.
if options.sparcel is None:
    options.sparcel = default_sparcel
else:
    # The option may reach this point as a raw string: convert it explicitly
    # and report a bad value the same way optparse does.
    try:
        options.sparcel = float(options.sparcel)
    except ValueError:
        parser.error('option -s: invalid floating-point value: %r'
                     % options.sparcel)
    if options.sparcel <= 0:
        parser.error('option -s: the size of one parcel must be positive')


# Measure the domain to parcellate: number of positions (nvox) and either
# the size of one voxel (volume data) or the total mesh area (surface data).
if data_type == 'volume':
    # The mask is read below to count the voxels, so it cannot stay unset.
    # Fail with an explicit message rather than inside read_volume(None).
    if options.maskFile is None:
        raise Exception('Mask file not provided, use [-m|--mask] option.')

    # Read the first input once: it provides the voxel size and, when the
    # mask file is missing, the data from which the mask is built.
    first_vol, first_meta = read_volume(args[0])

    if not op.exists(options.maskFile):
        logger.info('Could not find mask %s', options.maskFile)
        logger.info('Construct it as non-null voxels in first T image ...')
        m = (first_vol != 0).astype(np.int32)
        logger.info('Mask contains %s voxels', m.sum())
        write_volume(m, options.maskFile, first_meta)

    # Size of one voxel, from the 'pixdim' entry of the second metadata item
    # returned by read_volume (presumably the image header -- confirm).
    pixdim = first_meta[1]['pixdim']
    size_voxel = pixdim[1] * pixdim[2] * pixdim[3]  # in mm3, assuming mm

    # Total number of voxels: strictly positive entries of the mask
    nvox = (read_volume(options.maskFile)[0] > 0).sum()

elif data_type == 'surface':
    coord, triangles, mgii = read_mesh(options.mesh_file)
    # One position per mesh vertex
    nvox = len(coord)
    from pyhrf.tools import distance as dist

    def triangle_area(a, b, c):
        """Return the area of the triangle with vertices a, b and c.

        Uses Heron's formula on the three edge lengths given by `dist`.
        """
        l1, l2, l3 = dist(a, b), dist(a, c), dist(b, c)
        p = (l1 + l2 + l3) / 2.
        # For (nearly) degenerate triangles, rounding can make the product
        # slightly negative; clamp it so the square root stays defined.
        return max(p * (p - l1) * (p - l2) * (p - l3), 0.) ** .5

    # coord[t] selects the three vertices of triangle t
    total_surface_area = sum(triangle_area(*coord[t]) for t in triangles)
    logger.info('total surface area: %f mm2', total_surface_area)

# Total extent of the domain to parcellate: a volume for volume data, an
# area for surface data. Both branches below must use the quantity matching
# the data type, since size_voxel only exists for volume data.
if data_type == 'volume':
    total_size = size_voxel * nvox
else:
    total_size = total_surface_area

if options.nparcels is None:
    # Number of parcels chosen so that one parcel has a size equal to
    # options.sparcel
    n = total_size / float(options.sparcel)
    # A domain smaller than half a parcel would round to zero parcels;
    # keep at least one.
    options.nparcels = max(1, int(np.round(n)))
else:
    if options.nparcels <= 0:
        parser.error('option -n: the number of parcels must be positive')
    # The number of parcels is imposed: report the resulting parcel size
    options.sparcel = total_size / options.nparcels

logger.info('Size of one parcel: %1.2f mm%s',
            options.sparcel, '3' if data_type == 'volume' else '2')
logger.info('Nb of parcels: %d', options.nparcels)

# Without an explicit output file, derive one in the current directory from
# the method name, the number of parcels and the data extension.
if options.outFile is None:
    name_fields = (options.method, options.nparcels, data_extension)
    options.outFile = './parcel_GLM_%s_np%d.%s' % name_fields
    logger.info('Output file is %s', options.outFile)

if options.method != 'ward_spatial':
    # Every other method is delegated to make_parcellation_from_files,
    # which is only used here with volume inputs.
    if data_type != 'volume':
        raise Exception('Only volume data type supported for method %s'
                        % options.method)
    make_parcellation_from_files(args, options.maskFile, options.outFile,
                                 options.nparcels, options.method, options.dry,
                                 options.spatial_weight)

else:
    # Mask: loaded from file when given, otherwise every position is kept.
    if options.maskFile is None:
        mask = xndarray(np.ones(nvox, dtype=int), axes_names=['voxel'])
    else:
        mask = xndarray.load(options.maskFile)

    # Binarize the mask
    mask.data = mask.data > 0

    # Load each input and drop the length-1 dimensions that are not 3D
    # spatial axes (the original author notes this is a no-op for surface
    # data).
    loaded_data = []
    for fn in args:
        loaded_data.append(xndarray.load(fn).squeeze_all_but(MRI3Daxes))

    logger.info('First data volume:')
    logger.info(loaded_data[0].descrip())

    # Volume data are flattened through the mask; surface data are used as
    # loaded.
    if data_type != 'volume':
        flat_data = loaded_data
    else:
        flat_data = [c.cflatten(mask, new_axis='voxel') for c in loaded_data]

    # One feature per input file, arranged as (voxel, feature)
    func_data = stack_cuboids(flat_data, "feature")
    func_data.set_orientation(['voxel', 'feature'])
    logger.info('func_data:')
    logger.info(func_data.descrip())

    # Connectivity between positions: lattice neighbourhood inside the mask
    # for volumes, mesh triangles for surfaces.
    if data_type == 'volume':
        graph = graph_from_lattice(mask.data, kerMask3D_6n)
    else:
        graph = graph_from_mesh(triangles)

    labels = parcellation_ward_spatial(func_data.data,
                                       n_clusters=options.nparcels,
                                       graph=graph)

    # Wrap the labels and expand them back onto the mask domain
    clabels = xndarray(labels, axes_names=['position'],
                       meta_data=mask.meta_data)
    clabels = clabels.cexpand(mask, axis='position')

    logger.info('Save parcellation output to %s', options.outFile)
    clabels.save(options.outFile)
