#!/usr/bin/env python2
# -*- coding: utf-8 -*-

import sys
import logging

from optparse import OptionParser
from pprint import pformat

import numpy as np
import pyhrf
from pyhrf.tools._io import read_volume, write_volume

from pyhrf.graph import graph_from_lattice, kerMask3D_6n
from pyhrf.tools import non_existent_file, add_suffix

logger = logging.getLogger(__name__)

# Command-line interface: exactly one positional argument (the parcellation
# file) plus optional output path, size threshold and verbosity level.

# optparse substitutes the literal token "%prog" with the script name. The
# string is not a %-format template, so the percent sign must not be doubled
# (a doubled sign would print a stray "%" in front of the script name).
usage = 'usage: %prog [options] FILE'
description = 'Merge small parcels into other surrounding parcels. ' \
    'A given parcel is merged with the neighboring parcel which has the ' \
    'highest number of connections to it.'

parser = OptionParser(usage=usage, description=description)

# Accepted number of positional arguments. A negative maxArgs disables the
# upper bound in the check performed after parsing.
minArgs = 1
maxArgs = 1

parser.add_option('-o', '--output', dest='outFile', default=None,
                  help='Output file for the parcellation, default is '
                  'constructed from the input file by adding the '
                  'suffix "_merged"')


parser.add_option('-m', '--min-parcel-size', dest='min_size', default=15,
                  metavar='INT', type='int',
                  help='Parcels whose size is lower than or equal to this '
                  'value will be merged')


parser.add_option('-v', '--verbose', dest='verbose', metavar='VERBOSELEVEL',
                  type='int', default=0,
                  help=pformat(pyhrf.verbose_levels))


(options, args) = parser.parse_args()

# NOTE(review): the -v value is handed to setLevel() unchanged -- confirm that
# the levels listed in pyhrf.verbose_levels are meant to be used directly as
# logging levels.
pyhrf.logger.setLevel(options.verbose)

# Show the help and abort when the number of positional arguments is wrong.
nba = len(args)
if nba < minArgs or (maxArgs >= 0 and nba > maxArgs):
    parser.print_help()
    sys.exit(1)

# Merge every small parcel of the input parcellation into the neighboring
# parcel it shares the most links with, then write the result to disk.

min_size = options.min_size

parcellation_file = args[0]
parcellation, meta = read_volume(parcellation_file)

# Only strictly positive labels are treated as parcels.
mask = (parcellation > 0).astype(int)
in_mask = np.where(mask)

# Labels are used as array indices and fed to np.bincount, both of which
# require integers: cast explicitly so that volumes stored with a floating
# point data type are handled too.
parcellation_flat = parcellation[in_mask].astype(int)
if parcellation_flat.size == 0:
    logger.error('No parcel (label > 0) found in %s', parcellation_file)
    sys.exit(1)

# NOTE(review): the code below assumes that node i of the graph corresponds to
# the i-th position returned by np.where(mask) -- confirm against
# graph_from_lattice.
graph = graph_from_lattice(mask, kerMask3D_6n)

pmax = int(parcellation_flat.max())

# plinks[a, b]: number of links counted between parcels a and b. The matrix is
# filled symmetrically and its diagonal stays at zero.
plinks = np.zeros((pmax + 1, pmax + 1), dtype=int)

for i, nl in enumerate(graph):
    pi = parcellation_flat[i]
    for n in nl:
        pn = parcellation_flat[n]
        if pn != pi:
            plinks[pi, pn] += 1
            plinks[pn, pi] += 1

# Parcels holding at most min_size positions, in ascending label order so that
# the sequence of merges (which influences the result) is reproducible.
parcel_sizes = np.bincount(parcellation_flat)
small_parcels = [int(p) for p in np.unique(parcellation_flat)
                 if parcel_sizes[p] <= min_size]

logger.info('Parcels to merge: %s', str(small_parcels))
for p_id in small_parcels:
    p_to_merge = int(np.argmax(plinks[p_id, :]))
    nb_links = int(plinks[p_id, p_to_merge])
    if nb_links == 0:
        # No neighboring parcel: argmax over an all-zero row returns 0, which
        # would relabel the parcel as background. Keep it untouched instead.
        logger.warning('Parcel %d has no neighboring parcel, '
                       'it is left unchanged', p_id)
        continue
    logger.info('Merge parcel %d with %d (%d links):',
                p_id, p_to_merge, nb_links)
    parcellation_flat[parcellation_flat == p_id] = p_to_merge
    # Transfer the links of the merged parcel to the absorbing one, in both
    # directions so that the matrix stays symmetric and later merges rely on
    # up-to-date counts.
    plinks[p_to_merge, :] += plinks[p_id, :]
    plinks[:, p_to_merge] += plinks[:, p_id]
    # Links between the two merged parcels are now internal to one parcel.
    plinks[p_to_merge, p_to_merge] = 0
    # The merged label no longer exists.
    plinks[p_id, :] = 0
    plinks[:, p_id] = 0

if options.outFile is None:
    suf = '_merged'
    fnout = non_existent_file(add_suffix(parcellation_file, suf))
    options.outFile = fnout

# unflatten
final_parcellation = np.zeros_like(parcellation)
final_parcellation[in_mask] = parcellation_flat

logger.info('Save result to %s', options.outFile)
write_volume(final_parcellation.astype(np.int32), options.outFile, meta)
