#!/usr/bin/env python3
import os
import pandas as pd
import numpy as np
from biom import load_table
import click
from skbio.stats.composition import clr, clr_inv
from songbird.multinomial import MultRegression
from songbird.util import (read_metadata, match_and_filter, split_training,
                           silence_output)
from songbird.parameter_info import DESCS, DEFAULTS
import warnings

# Import TensorBoard's hparams API and TensorFlow with FutureWarning silenced.
# The filter is installed inside `catch_warnings()`, so the previous warning
# configuration is restored as soon as the `with` block exits.
with warnings.catch_warnings():
    warnings.filterwarnings("ignore", category=FutureWarning)
    from tensorboard.plugins.hparams import api as hp
    import tensorflow as tf

# Unlike the scoped filter above, this one is installed at module level and
# therefore stays active: RuntimeWarning is ignored from here on.
warnings.filterwarnings("ignore", category=RuntimeWarning)


# Root click command group. Subcommands (e.g. `multinomial`) register
# themselves on it through the `@songbird.command()` decorator.
@click.group()
def songbird():
    # No work to do here: the group only acts as the attachment point for
    # its subcommands.
    pass


@songbird.command()
@click.option(
    "--input-biom", show_default=True, required=True, help=DESCS["table"]
)
@click.option(
    "--metadata-file", show_default=True, required=True, help=DESCS["metadata"]
)
@click.option(
    "--formula", show_default=True, required=True, help=DESCS["formula"]
)
@click.option(
    "--training-column",
    default=DEFAULTS["training-column"],
    show_default=True,
    help=DESCS["training-column"],
)
@click.option(
    "--num-random-test-examples",
    default=DEFAULTS["num-random-test-examples"],
    show_default=True,
    help=DESCS["num-random-test-examples"],
)
@click.option(
    "--epochs",
    show_default=True,
    default=DEFAULTS["epochs"],
    help=DESCS["epochs"],
)
@click.option(
    "--batch-size",
    show_default=True,
    help=DESCS["batch-size"],
    default=DEFAULTS["batch-size"],
)
@click.option(
    "--differential-prior",
    show_default=True,
    default=DEFAULTS["differential-prior"],
    help=DESCS["differential-prior"],
)
@click.option(
    "--learning-rate",
    show_default=True,
    default=DEFAULTS["learning-rate"],
    help=DESCS["learning-rate"],
)
@click.option(
    "--clipnorm",
    show_default=True,
    default=DEFAULTS["clipnorm"],
    help=DESCS["clipnorm"],
)
@click.option(
    "--min-sample-count",
    show_default=True,
    default=DEFAULTS["min-sample-count"],
    help=DESCS["min-sample-count"],
)
@click.option(
    "--min-feature-count",
    show_default=True,
    default=DEFAULTS["min-feature-count"],
    help=DESCS["min-feature-count"],
)
@click.option(
    "--checkpoint-interval",
    show_default=True,
    default=DEFAULTS["checkpoint-interval"],
    help=DESCS["checkpoint-interval"],
)
@click.option(
    "--summary-interval",
    show_default=True,
    default=DEFAULTS["summary-interval"],
    help=DESCS["summary-interval"],
)
@click.option(
    "--summary-dir",
    default=DEFAULTS["summary-dir"],
    show_default=True,
    help=DESCS["summary-dir"],
)
@click.option(
    "--random-seed",
    default=DEFAULTS["random-seed"],
    show_default=True,
    help=DESCS["random-seed"],
    type=int,
)
@click.option(
    "--silent/--no-silent",
    default=DEFAULTS["silent"],
    show_default=True,
    help=DESCS["silent"],
)
def multinomial(
    input_biom,
    metadata_file,
    formula,
    training_column,
    num_random_test_examples,
    epochs,
    batch_size,
    differential_prior,
    learning_rate,
    clipnorm,
    min_sample_count,
    min_feature_count,
    checkpoint_interval,
    summary_interval,
    summary_dir,
    random_seed,
    silent,
):
    """Fit a multinomial regression and write the feature differentials."""
    # Every parameter is supplied by the click option of the same name
    # declared above; their meanings and defaults live in
    # songbird.parameter_info (DESCS / DEFAULTS).
    #
    # Outputs: whatever MultRegression writes under `summary_dir`, an hparams
    # summary, and `<summary_dir>/differentials.tsv`.
    if silent:
        silence_output()

    # load metadata and tables
    metadata = read_metadata(metadata_file)
    table = load_table(input_biom)

    # match them
    table, metadata, design = match_and_filter(
        table, metadata, formula, min_sample_count, min_feature_count
    )

    # Convert to a dense representation. Ask biom for the dense frame
    # directly: pandas' `DataFrame.to_dense()` is deprecated since 0.25 and is
    # not available in pandas >= 1.0. The transpose puts the table's sample
    # axis on the rows.
    dense_table = table.to_dataframe(dense=True).T

    # Run configuration recorded through the TensorBoard hparams plugin.
    hparams = {
        'input_biom': input_biom,
        'metadata_file': metadata_file,
        'formula': formula,
        'num_random_test_examples': num_random_test_examples,
        'epochs': epochs,
        'batch_size': batch_size,
        'differential_prior': differential_prior,
        'learning_rate': learning_rate,
        # clipnorm is handed to the model below, so log it with the rest.
        'clipnorm': clipnorm,
        'min_sample_count': min_sample_count,
        'min_feature_count': min_feature_count,
        'silent': silent,
    }
    # The seed is only logged when one was actually provided.
    if random_seed is not None:
        hparams['random_seed'] = random_seed

    # split up training and testing
    trainX, testX, trainY, testY = split_training(
        dense_table,
        metadata,
        design,
        training_column,
        num_random_test_examples,
        seed=random_seed,
    )

    # initialize and train the model
    # NOTE(review): the --differential-prior value is forwarded as
    # `beta_mean`; confirm against MultRegression that this is the intended
    # keyword.
    model = MultRegression(
        learning_rate=learning_rate,
        clipnorm=clipnorm,
        beta_mean=differential_prior,
        batch_size=batch_size,
        save_path=summary_dir,
    )
    with tf.Graph().as_default(), tf.Session() as session:
        # set the tf random seed
        if random_seed is not None:
            tf.set_random_seed(random_seed)

        model(session, trainX, trainY, testX, testY)

        model.fit(
            epochs=epochs,
            summary_interval=summary_interval,
            checkpoint_interval=checkpoint_interval,
            silent=silent,
        )

        summary_writer = tf.contrib.summary.create_file_writer(summary_dir)

        # Build the hparams summary op while the writer is the default, then
        # execute it in the session once the writer has been initialized.
        with summary_writer.as_default(), \
                tf.contrib.summary.always_record_summaries():
            hps = hp.hparams(hparams)
            tf.contrib.summary.initialize(graph=session.graph)

        session.run(hps)

    md_ids = np.array(design.columns)
    obs_ids = table.ids(axis="observation")

    # Prepend a zero column to the fitted coefficients, then round-trip
    # through clr_inv / clr.
    # NOTE(review): presumably the zero column stands for the reference
    # feature and the round trip re-centers the coefficients - confirm
    # against MultRegression.
    beta_ = clr(clr_inv(np.hstack((np.zeros((model.p, 1)), model.B))))

    df = pd.DataFrame(beta_.T, columns=md_ids, index=obs_ids)
    df.index.name = "featureid"

    # Make sure the output directory exists before writing, instead of
    # relying on the model having created it.
    os.makedirs(summary_dir, exist_ok=True)
    df.to_csv(os.path.join(summary_dir, "differentials.tsv"), sep="\t")


# Dispatch to the click command group when the file is executed as a script.
if __name__ == "__main__":
    songbird()
