#!/usr/bin/env python2
# -*- coding: utf-8 -*-

import sys
import logging

from optparse import OptionParser
from pprint import pformat

import numpy as np

import pyhrf
from pyhrf.tools import add_suffix
from pyhrf.tools._io import write_volume, read_volume
from pyhrf.graph import kerMask3D_6n, split_mask_into_cc_iter

logger = logging.getLogger(__name__)

# Bounds on the number of positional arguments. A negative maximum disables
# the upper-bound check performed after parsing.
minArgs = 1
maxArgs = -1

# optparse substitutes the program name for '%prog'. The string is never
# passed through %-formatting, so the percent sign must not be doubled
# (a doubled one shows up as a stray '%' in front of the program name).
usage = 'usage: %prog [options] FILE'
description = 'Extract connected components from a 3D binary mask FILE.'\
    ' Outputs will be in the form FILE_s%06d_cc%06d.nii, where "s%06d" is'\
    ' the size of the connected component.'

parser = OptionParser(usage=usage, description=description)

# The help text is the pretty-printed content of pyhrf.verbose_levels.
parser.add_option('-v', '--verbose', dest='verbose', metavar='VERBOSELEVEL',
                  type='int', default=0,
                  help=pformat(pyhrf.verbose_levels))

parser.add_option('-m', '--min-size', dest='min_size', metavar='INT',
                  type='int', default=10,
                  help='Minimum size of an extracted connected component.')


parser.add_option('-b', '--biggest', dest='keep_biggest', metavar='BOOL',
                  action='store_true', default=False,
                  help='Keep only the biggest CC.')

# Connexity is parsed as a string choice and compared as a string later on.
choices = ['6', '27']
parser.add_option('-c', '--connexity', dest='connexity', metavar='INT',
                  type='choice', default='6', choices=choices,
                  help='Connexity for the local neighbourhood system, '
                  'choices are: %s. Default is 6'
                  % ','.join(choices))

(options, args) = parser.parse_args()
# NOTE(review): the raw -v value is handed to setLevel() as a logging level;
# confirm this matches the numbering documented by pyhrf.verbose_levels.
pyhrf.logger.setLevel(options.verbose)

# Show the help and fail when the number of positional arguments is outside
# [minArgs, maxArgs] (the upper bound only applies when maxArgs >= 0).
nba = len(args)
if nba < minArgs or (maxArgs >= 0 and nba > maxArgs):
    parser.print_help()
    sys.exit(1)

# Neighbourhood kernel handed to the splitter: None stands for full
# connexity (27 neighbours), kerMask3D_6n is used for 6-connexity.
km = None
if options.connexity == '6':
    km = kerMask3D_6n

input_mask_file = args[0]
img3D, meta = read_volume(input_mask_file)

# A binary mask holds exactly the two labels 0 and 1. np.unique returns the
# sorted distinct values, so one comparison against [0, 1] covers both the
# number of labels and their values. (The previous test used all(), which
# only rejected masks where *both* labels were wrong, e.g. it accepted
# masks labelled {0, 2}.)
labels = np.unique(img3D)
if not np.array_equal(labels, [0, 1]):
    print('Input mask %s is not a binary mask' % input_mask_file)
    sys.exit(1)

logger.info('Extract connected components ...')
cc_iter = split_mask_into_cc_iter(img3D, options.min_size, km)
if not options.keep_biggest:
    # Write every component to its own file, tagged with size and index.
    nb_cc = 0
    for i, mcc in enumerate(cc_iter):
        size = mcc.sum()
        fn_out = add_suffix(input_mask_file, '_s%06d_cc%06d' % (size, i))
        write_volume(mcc, fn_out, meta)
        # Count explicitly: the loop index is one less than the number of
        # components and is undefined when nothing is yielded.
        nb_cc = i + 1
    logger.info('Done. Extracted %d connected components.', nb_cc)
else:
    # Keep only the component with the largest voxel count.
    # NOTE(review): this holds a reference to a yielded array; it assumes
    # the iterator yields a distinct array per component -- confirm.
    mcc_biggest = None
    size_max = 0
    for mcc in cc_iter:
        size = mcc.sum()
        if size > size_max:
            mcc_biggest = mcc
            size_max = size

    if mcc_biggest is None:
        # Nothing was yielded (or every component was empty): there is no
        # volume to write, so fail with a clear message instead of a
        # NameError on the unbound variable.
        print('No connected component found in %s' % input_mask_file)
        sys.exit(1)

    fn_out = add_suffix(input_mask_file, '_s%06d' % (size_max))
    write_volume(mcc_biggest, fn_out, meta)
    logger.info('Done. Extracted the biggest connected '
                'component (size=%d).', size_max)