#!python

# Copyright (C) 2018, 2019 Simon Dirmeier
#
# This file is part of pybda.
#
# pybda is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# pybda is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with pybda. If not, see <http://www.gnu.org/licenses/>.
#
# @author = 'Simon Dirmeier'
# @email = 'simon.dirmeier@bsse.ethz.ch'


import logging

import click
import snakemake
import yaml

from pybda.logger import logger_format


def run(config, spark, cmd):
    """
    Execute a single target of the pybda snakemake workflow.

    :param config: path of the config file, handed to snakemake as
     ``configfile``
    :param spark: value stored under the ``"sparkip"`` key of the snakemake
     config
    :param cmd: name of the snakemake target to build
    :return: the result of ``snakemake.snakemake`` (per the snakemake API a
     flag telling whether the workflow succeeded), so that callers are able
     to detect a failed run instead of the outcome being silently dropped
    """
    # Imported inside the function, exactly as the original code did.
    from pybda import snake_file

    return snakemake.snakemake(
        snakefile=snake_file(),
        targets=[cmd],
        configfile=config,
        config={"sparkip": spark},
    )


def _get_from_config(config, key):
    """
    Read a single top-level entry from a YAML config file.

    :param config: path of the YAML config file
    :param key: top-level key to look up
    :return: the value stored under ``key``
    :raises ValueError: if the file does not contain a top-level mapping
     (e.g. it is empty) or the mapping has no entry ``key``
    """
    with open(config, 'r') as fh:
        # `yaml.load` without an explicit Loader is deprecated since
        # PyYAML 5.1 and a TypeError from PyYAML 6 on. `safe_load` also
        # refuses to construct arbitrary Python objects from the file.
        conf_ = yaml.safe_load(fh)
    # An empty file parses to None; report that like a missing key instead
    # of failing with an AttributeError.
    if not isinstance(conf_, dict) or key not in conf_:
        raise ValueError("'{}' argument not found in config file".format(key))
    return conf_[key]


@click.group()
def cli():
    # Callback of the click command group: sets up logging via
    # logging.basicConfig with the format returned by
    # pybda.logger.logger_format.
    # Documented with comments on purpose: a docstring here would be picked
    # up by click as the help text of the group.
    log_format = logger_format()
    logging.basicConfig(format=log_format)


@cli.command()
@click.argument("config", type=str)
@click.argument("spark", type=str)
def dimension_reduction(config, spark):
    """
    Computes a dimension reduction from a CONFIG in a SPARK session.
    """

    # The "dimension_reduction" entry of the config file is used as the
    # snakemake target that gets executed.
    target = _get_from_config(config, "dimension_reduction")
    run(config=config, spark=spark, cmd=target)


@cli.command()
@click.argument("config", type=str)
@click.argument("spark", type=str)
def clustering(config, spark):
    """
    Do a clustering fit of a data set.
    """

    # The "clustering" entry of the config file is used as the snakemake
    # target that gets executed.
    target = _get_from_config(config, "clustering")
    run(config=config, spark=spark, cmd=target)


@cli.command()
@click.argument("config", type=str)
@click.argument("spark", type=str)
def regression(config, spark):
    """
    Fit a regression model to a data set
    """

    # The "regression" entry of the config file is used as the snakemake
    # target that gets executed.
    target = _get_from_config(config, "regression")
    run(config=config, spark=spark, cmd=target)


@cli.command()
@click.argument("config", type=str)
@click.argument("spark", type=str)
@click.argument("input", type=str)
@click.argument("output", type=str)
@click.argument("n", type=int)
@click.option("--split", type=bool, default=True)
def sample(config, spark, input, output, n, split):
    """
    Take a sample from a data set from a CONFIG in a SPARK session. In addition
    an INPUT and OUTPUT as well as the number of samples N and if features
    should be SPLIT must be provided.
    """

    # Imported inside the function, mirroring how `run` obtains the
    # workflow file.
    from pybda import snake_file

    workflow = snake_file()
    # Command line values forwarded to the workflow as snakemake config
    # entries; `n` and `split` are passed on as strings.
    overrides = {
        "sparkip": spark,
        "input": input,
        "output": output,
        "n": str(n),
        "split": str(split),
    }
    snakemake.snakemake(
        snakefile=workflow,
        targets=["sample"],
        configfile=config,
        config=overrides,
    )


# Invoke the click command group when this file is executed as a script.
if __name__ == "__main__":
    cli()
