#!/usr/bin/env python2
# -*- coding: utf-8 -*-

import sys
import string
import logging

from pprint import pformat
from optparse import OptionParser
from time import time

import numpy as np

import pyhrf

from pyhrf.tools._io import read_volume, write_volume
from pyhrf.ndarray import expand_array_in_mask
from pyhrf.graph import graph_from_lattice
from pyhrf.parcellation import split_parcel
from pyhrf.tools import non_existent_file, add_suffix, format_duration
from pyhrf.graph import split_mask_into_cc_iter

logger = logging.getLogger(__name__)

# optparse expands the literal token "%prog" to the script name and strips a
# leading "usage: " before adding its own prefix. "%%" is NOT an escape here,
# so the token must be written with a single percent sign.
usage = 'usage: %prog [options] FILE'
description = 'Build a parcellation from spatial criterion. FILE can be' \
    ' a nifti binary mask file or a gifti mesh file. Method is a '\
    'balanced spatial partitionning which provides a partition ' \
    'comprising connected components with equal sizes.'

parser = OptionParser(usage=usage, description=description)

# Accepted number of positional arguments. A negative maximum disables the
# upper bound (see the check performed after parsing).
minArgs = 1
maxArgs = -1

parser.add_option('-o', '--output', dest='outFile', default=None,
                  help='Output file for the parcellation, default is '
                  'constructed from the input file by adding the '
                  'suffix "parcellation"')

parser.add_option('-s', '--parcel-size', dest='psize', default=200,
                  metavar='INT', type='int',
                  help='Define the parcel size. Ignored if option '
                  '"nb-parcels" is provided, default is %default.')

parser.add_option('-n', '--nb-parcels', dest='nparcels', default=None,
                  metavar='INT', type='int',
                  help='Number of parcels. If not provided then use '
                  ' option "parcel-size" to determine it.')

parser.add_option('-c', '--voronoi-seeds', dest='voronoi_seeds',
                  default=None, help='Image Volume containing seeds (integer '
                  ' labels) for the Voronoi parcellation. The number of '
                  'parcels is defined by them so "--nb-parcels" and '
                  '"--parcel-size" are ignored')

# The first entry is the default method.
choices = ['balanced', 'voronoi', 'arbitrary']
parser.add_option('-m', '--method', type='choice', choices=choices,
                  dest='method', metavar='STRING', default=choices[0],
                  help=('Parcellation method. Choices are:\n %s, '
                        % ', '.join(choices)) + 'default is %default')


# Verbosity flag: the help text is the pretty-printed content of
# pyhrf.verbose_levels.
verbosity_help = pformat(pyhrf.verbose_levels)
parser.add_option(
    '-v', '--verbose',
    dest='verbose',
    metavar='VERBOSELEVEL',
    type='int',
    default=0,
    help=verbosity_help)


(options, args) = parser.parse_args()

pyhrf.logger.setLevel(options.verbose)
# pyhrf.verbose.set_verbosity(options.verbose)

# Check the number of positional arguments; maxArgs < 0 means "no upper bound".
nba = len(args)
if nba < minArgs or (maxArgs >= 0 and nba > maxArgs):
    parser.print_help()
    sys.exit(1)

# Reject sizes/counts that would lead to a division by zero or to a
# meaningless parcellation, before doing any I/O.
if options.nparcels is not None and options.nparcels < 1:
    parser.error('option --nb-parcels must be a positive integer')
if options.nparcels is None and options.psize < 1:
    parser.error('option --parcel-size must be a positive integer')

mask_file = args[0]
mask, mask_header = read_volume(mask_file)
# Binarise the input: every strictly positive value is inside the mask.
mask = (mask > 0).astype(np.int32)

nvox = mask.sum()
if options.nparcels is None:
    # Derive the parcel count from the requested parcel size. Use a true
    # division (an integer division would make the rounding a no-op) and
    # keep at least one parcel when the mask is smaller than a single parcel.
    options.nparcels = max(1, int(np.round(nvox * 1. / options.psize)))
else:
    # The parcel count was given explicitly: report the implied parcel size.
    options.psize = int(np.round(nvox * 1. / options.nparcels))

logger.info('Size of input mask: %d', nvox)

logger.info('Nb of parcels: %d, parcel size: ~%d',
            options.nparcels, options.psize)

t0 = time()

# Work with a strictly positive Python int: depending on how it was derived,
# options.nparcels may be a numpy scalar, a float, or 0.
nb_parcels = max(int(options.nparcels), 1)

if options.method == 'balanced':
    # Every connected component (CC) of the mask is split on its own, into a
    # number of parcels proportional to its share of the whole mask.
    # TODO: if g is not connex, try with full connexity
    # TODO: use pyhrf_extract_cc_vol

    logger.info('Splitting ...')

    parcellation = np.zeros_like(mask)
    for cc_mask in split_mask_into_cc_iter(mask):
        cc_size = cc_mask.sum()
        logger.info('Treating a connected component (CC) of %d positions',
                    cc_size)
        g = graph_from_lattice(cc_mask)

        cc_np = max(int(np.round(nb_parcels * cc_size / (nvox * 1.))), 1)
        logger.info('Split (CC) into %d parcels', cc_np)
        # All positions start with label 1; split_parcel is asked to work
        # in place on that single parcel.
        cc_labels = np.ones(cc_size, dtype=int)
        if cc_np > 1:
            split_parcel(cc_labels, {1: g}, 1, cc_np, inplace=True,
                         verbosity=2, balance_tolerance='draft')
        logger.info('Split done!')

        # Shift by the current maximum so labels stay unique across CCs.
        maxp = parcellation.max()
        parcellation += expand_array_in_mask(cc_labels + maxp, cc_mask > 0)


elif options.method == 'voronoi':
    # The voronoi helper moved between nipy versions.
    try:
        from nipy.algorithms.clustering.clustering import voronoi
    except ImportError:
        from nipy.algorithms.clustering.utils import voronoi
    from pyhrf.tools import peelVolume3D

    parcellation = np.zeros_like(mask)
    if options.voronoi_seeds is not None:
        all_seeds, _ = read_volume(options.voronoi_seeds)

    for cc_mask in split_mask_into_cc_iter(mask):
        cc_size = cc_mask.sum()
        logger.info('Treating a connected component (CC) of %d positions',
                    cc_size)
        if cc_size < 6:
            # Components this small are skipped and keep label 0.
            continue
        if options.voronoi_seeds is None:
            # perform voronoi on random seeds
            # NOTE(review): peelVolume3D presumably erodes the component so
            # that seeds are not drawn on its border -- confirm.
            eroded_mask = peelVolume3D(cc_mask)
            eroded_mask_size = eroded_mask.sum()
            if eroded_mask_size == 0:
                # Nothing left after peeling: draw the seeds from the current
                # component itself (not from the whole mask, which would put
                # seeds outside of this component).
                eroded_mask_size = cc_size
                eroded_mask = cc_mask.copy()
            # NOTE(review): indices are drawn with replacement, so fewer than
            # nb_parcels distinct seeds may be obtained, and every CC is
            # given the full parcel count -- confirm this is intended.
            seeds = np.random.randint(0, eroded_mask_size, nb_parcels)
            mask_for_seed = np.zeros(eroded_mask_size, dtype=int)
            mask_for_seed[seeds] = 1
            mask_for_seed = expand_array_in_mask(mask_for_seed, eroded_mask)
        else:
            # Keep only the user-provided seeds lying in this component.
            mask_for_seed = all_seeds * cc_mask

        logger.info('Nb of seeds in current CC: %d', mask_for_seed.sum())
        cc_parcellation = voronoi(np.vstack(np.where(cc_mask)).T,
                                  np.vstack(np.where(mask_for_seed)).T) + 1
        logger.info('CC parcellation labels: %s',
                    str(np.unique(cc_parcellation)))
        # Shift by the current maximum so labels stay unique across CCs.
        maxp = parcellation.max()
        parcellation += expand_array_in_mask(cc_parcellation + maxp, cc_mask)
        logger.info('Current parcellation labels: %s',
                    str(np.unique(parcellation)))
    logger.info('voronoi parcellation: %s, %s',
                str(parcellation.shape), str(parcellation.dtype))
    # split non-connex parcels
elif options.method == 'arbitrary':
    # Label the in-mask positions by consecutive runs of equal length, in
    # the order returned by np.where, with no spatial criterion.
    parcellation = np.zeros_like(mask)
    pids = np.arange(1, nb_parcels + 1, dtype=np.int32)
    logger.debug('pids: %s', pids)
    logger.debug('mask: %d', mask.sum())

    # np.repeat needs an integer repeat count; np.ceil returns a float.
    psize = int(np.ceil(nvox * 1. / nb_parcels))
    repeated_pids = np.repeat(pids, psize)
    in_mask = np.where(mask)
    logger.debug('np.repeat(pids, psize): %s', repeated_pids.shape)
    logger.debug('parcellation[np.where(mask)] %s',
                 parcellation[in_mask].shape)
    # nb_parcels * psize >= nvox, so truncate to the number of positions.
    parcellation[in_mask] = repeated_pids[:mask.sum()]

logger.info('Parcellation done in %s', format_duration(time()-t0))


# Without an explicit -o/--output, build the output name from the input file
# name and the method used.
if options.outFile is None:
    default_name = add_suffix(mask_file, '_parcellation_' + options.method)
    options.outFile = non_existent_file(default_name)

logger.info('Save result to %s', options.outFile)
# Labels are written as 32-bit integers, reusing the header of the input mask.
parcellation_int32 = parcellation.astype(np.int32)
write_volume(parcellation_int32, options.outFile, mask_header)
