#!/usr/bin/env python

# Copyright (c) 2015, Ecole Polytechnique Federale de Lausanne, Blue Brain Project
# All rights reserved.
#
# This file is part of NeuroM <https://github.com/BlueBrain/NeuroM>
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
#     1. Redistributions of source code must retain the above copyright
#        notice, this list of conditions and the following disclaimer.
#     2. Redistributions in binary form must reproduce the above copyright
#        notice, this list of conditions and the following disclaimer in the
#        documentation and/or other materials provided with the distribution.
#     3. Neither the name of the copyright holder nor the names of
#        its contributors may be used to endorse or promote products
#        derived from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY
# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

'''Examples of extracting basic statistics'''
import collections
import argparse
import json
import logging
import os
import sys

from functools import partial
import numpy as np

from neurom.stats import scalar_stats
from neurom.core.types import NEURITES
from neurom.ezy import load_neuron, load_neurons
from neurom.io.utils import get_morph_files

# Module-level logger; its level is configured in the __main__ block below.
L = logging.getLogger(__name__)


# Names of methods looked up (via pfunc/getattr) on each neuron object and
# called once per neurite type in NEURITES, with that type as the single
# argument. The first four characters (the 'get_' prefix) are stripped to form
# the key in the output dict.
NEURITE_STATS = [
    'get_section_lengths',
    'get_segment_lengths',
    'get_local_bifurcation_angles',
    'get_remote_bifurcation_angles',
    'get_n_sections_per_neurite',
    'get_n_sections',
    'get_n_neurites'
]


# Names of methods called on each neuron object with no arguments; their
# results are summarised once, not broken down by neurite type. The 'get_'
# prefix is likewise stripped for the output key.
NEURON_STATS = [
    'get_soma_radius',
    'get_soma_surface_area',
]


def parse_args():
    '''Build the command-line interface and return the parsed namespace.

    Accepts a positional ``datapath``, a repeatable ``-v/--verbose`` counter
    and an ``--as-population`` flag.
    '''
    cli = argparse.ArgumentParser(
        description='Morphology statistics extractor',
        epilog='Note: Outputs json',
    )

    cli.add_argument(
        'datapath',
        help='Path to a morphology data file or a directory',
    )

    # Counts occurrences: 0 = default, 1 = -v, 2 = -vv.
    cli.add_argument(
        '-v', '--verbose',
        action='count',
        dest='verbose',
        default=0,
        help='-v for INFO, -vv for DEBUG',
    )

    cli.add_argument(
        '--as-population',
        action='store_true',
        default=False,
        help='If enabled the directory is treated as a population',
    )

    return cli.parse_args()


def _flatten_array(l):
    '''Flattens an array of any depth.
    '''
    if isinstance(l, collections.Iterable):
        for el in l:
            if isinstance(el, collections.Iterable) and not isinstance(el, basestring):
                for sub in _flatten_array(el):
                    yield sub
            else:
                yield el
    else:
        yield l


def pfunc(ns, n=None):
    '''Return a one-argument callable that invokes method ``ns`` on its input.

    Args:
        ns: name of the method to look up on the object passed to the result.
        n: optional single argument forwarded to that method; when None the
            method is called with no arguments.

    Returns:
        A ``functools.partial`` object ``f`` such that ``f(obj)`` evaluates
        ``getattr(obj, ns)(n)``, or ``getattr(obj, ns)()`` when ``n`` is None.
    '''
    def _call_no_arg(obj, ns):
        return getattr(obj, ns)()

    def _call_with_arg(obj, ns, n):
        return getattr(obj, ns)(n)

    if n is not None:
        return partial(_call_with_arg, ns=ns, n=n)
    return partial(_call_no_arg, ns=ns)


def eval_stats(func, neurons):
    '''Apply ``func`` to every neuron and summarise the pooled results.

    Each ``func(nrn)`` may return a scalar or an arbitrarily nested iterable;
    all results are flattened into one float array before ``scalar_stats`` is
    applied.

    Args:
        func: callable taking a single neuron object.
        neurons: iterable of neuron objects (iterated once).

    Returns:
        Whatever ``scalar_stats`` returns for the pooled values. If it raises
        ValueError (presumably because no values were collected - confirm), a
        dict mapping 'min', 'max', 'median', 'mean' and 'std' to 0. is
        returned instead.
    '''
    # Builtin float: the np.float alias was deprecated in NumPy 1.20 and
    # removed in 1.24; it always meant the builtin float (float64 dtype).
    values = np.fromiter(_flatten_array(func(nrn) for nrn in neurons), float)

    try:
        return scalar_stats(values)
    except ValueError:
        return dict.fromkeys(('min', 'max', 'median', 'mean', 'std'), 0.)


def extract_stats(neurons):
    '''Compute every configured statistic over ``neurons``.

    Args:
        neurons: collection of neuron objects. It is iterated once per
            statistic (and per neurite type), so it must be re-iterable
            (e.g. a list or tuple, not a generator).

    Returns:
        dict keyed by statistic name with the 'get_' prefix removed. Entries
        from NEURITE_STATS map neurite-type names to summaries; entries from
        NEURON_STATS map directly to a summary.
    '''
    result = {}

    for getter in NEURITE_STATS:
        by_type = {}
        result[getter[4:]] = by_type
        for neurite_type in NEURITES:
            summary = eval_stats(pfunc(getter, neurite_type), neurons)
            by_type[neurite_type.name] = summary
            L.debug('Stat: %s, Neurite: %s, Type: %s', getter, neurite_type, type(summary))

    # Whole-neuron statistics take no neurite-type argument.
    result.update((getter[4:], eval_stats(pfunc(getter), neurons))
                  for getter in NEURON_STATS)

    return result


if __name__ == '__main__':
    args = parse_args()
    # 0 -> WARNING, 1 -> INFO, 2 or more -> DEBUG
    logging.basicConfig(level=(logging.WARNING,
                               logging.INFO,
                               logging.DEBUG)[min(args.verbose, 2)])

    _f = args.datapath
    # Maps each input path to the stats dict produced by extract_stats.
    _results = {}

    if os.path.isfile(_f):
        _results[_f] = extract_stats((load_neuron(_f),))
    elif os.path.isdir(_f):
        if not args.as_population:
            # One entry per path returned by get_morph_files.
            for _p in get_morph_files(_f):
                _results[_p] = extract_stats((load_neuron(_p),))
        else:
            # A single entry, keyed by the directory, pooling all its neurons.
            _results[_f] = extract_stats(load_neurons(_f))
    else:
        L.error("Invalid data path %s", _f)
        sys.exit(1)

    # Function-call form: valid under both Python 2 and Python 3.
    print(json.dumps(_results, indent=2, separators=(',', ':')))
