#!python

# Copyright (C) 2018, 2019 Simon Dirmeier
#
# This file is part of pybda.
#
# pybda is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# pybda is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with pybda. If not, see <http://www.gnu.org/licenses/>.
#
# @author = 'Simon Dirmeier'
# @email = 'simon.dirmeier@bsse.ethz.ch'


import logging

import click
import snakemake
import yaml

from pybda.config.config_checks import check_args
from pybda.globals import DIM_RED__, CLUSTERING__, REGRESSION__
from pybda.logger import logger_format


def _run(config, spark, cmd):
    """
    Execute the pybda Snakemake workflow for the requested targets.

    :param config: path of the config file, handed to snakemake as
     ``configfile``
    :param spark: value stored under the ``"sparkip"`` key of the
     workflow config
    :param cmd: a single target name or a list of target names
    :return: whatever ``snakemake.snakemake`` returns (according to the
     snakemake API documentation a bool signalling success)
    """
    # Imported lazily, exactly as before, so that importing this module does
    # not pull in the package-level helper.
    from pybda import snake_file

    # snakemake's API documents `targets` as a list of names. Some callers
    # hand over a bare string (e.g. "sample"), so wrap it instead of passing
    # a string where a list is expected.
    if isinstance(cmd, str):
        cmd = [cmd]

    return snakemake.snakemake(
        snakefile=snake_file(),
        targets=cmd,
        configfile=config,
        config={"sparkip": spark},
        lock=False,
        unlock=False,
    )


def _methods(conf, str):
    """
    Return the comma-separated entries stored under key `str` of `conf`
    as a list of strings.

    `check_args(conf, str)` is called before the value is read; presumably
    it validates the entry -- confirm against pybda.config.config_checks.
    Every space character is removed from the value before it is split
    on commas.
    """
    check_args(conf, str)
    raw_value = conf[str]
    without_spaces = raw_value.replace(" ", "")
    return without_spaces.split(",")


def _get_from_config(config, key):
    """
    Collect the method names configured under one or several keys.

    :param config: path of a YAML config file
    :param key: a single key or a list of keys to look up in the config
    :return: flat list holding the methods of every requested key that is
     present in the config, in the order the keys were given; empty when
     none of the keys is present or the file contains no YAML document
    """
    # NOTE(review): FullLoader is not PyYAML's restricted loader. If config
    # files may come from untrusted sources, consider yaml.safe_load instead.
    with open(config, 'r') as fh:
        conf_ = yaml.load(fh, Loader=yaml.FullLoader)

    # An empty YAML file parses to None, which would make the membership
    # test below raise a TypeError; treat it as a config without entries.
    if conf_ is None:
        conf_ = {}

    keys = key if isinstance(key, list) else [key]

    methods = []
    for method in keys:
        # Keys missing from the config are skipped silently.
        if method in conf_:
            methods.extend(_methods(conf_, method))

    return methods


@click.group()
def cli():
    # Root command group: configures the root logger with the package's log
    # format. (A comment rather than a docstring, since click would print a
    # docstring as help text.)
    log_format = logger_format()
    logging.basicConfig(format=log_format)


@cli.command()
@click.argument("config", type=str)
@click.argument("spark", type=str)
def sample(config, spark):
    """
    Subsample a data set down to a specified fraction
     from a CONFIG in a SPARK session.
    """
    # The target is fixed for this command; nothing is read from the config
    # file here.
    target = "sample"
    _run(config, spark, target)


@cli.command()
@click.argument("config", type=str)
@click.argument("spark", type=str)
def dimension_reduction(config, spark):
    """
    Computes a dimension reduction from a CONFIG in a SPARK session.
    """
    # Targets are the methods listed under the DIM_RED__ key of the config.
    targets = _get_from_config(config, DIM_RED__)
    _run(config, spark, targets)


@cli.command()
@click.argument("config", type=str)
@click.argument("spark", type=str)
def clustering(config, spark):
    """
    Do a clustering fit from a CONFIG in a SPARK session.
    """
    # Targets are the methods listed under the CLUSTERING__ key of the config.
    targets = _get_from_config(config, CLUSTERING__)
    _run(config, spark, targets)


@cli.command()
@click.argument("config", type=str)
@click.argument("spark", type=str)
def regression(config, spark):
    """
    Fit a regression model from a CONFIG in a SPARK session.
    """
    # Targets are the methods listed under the REGRESSION__ key of the config.
    targets = _get_from_config(config, REGRESSION__)
    _run(config, spark, targets)


@cli.command()
@click.argument("config", type=str)
@click.argument("spark", type=str)
def run(config, spark):
    """
    Execute all tasks defined in a CONFIG in a SPARK session.
    """
    # Look up every supported section; the resulting targets follow this
    # order: regression, dimension reduction, clustering.
    targets = _get_from_config(
        config,
        [REGRESSION__, DIM_RED__, CLUSTERING__],
    )
    _run(config, spark, targets)


# When the module is executed directly as a script, dispatch to the click
# command group defined above.
if __name__ == "__main__":
    cli()
