#!/usr/bin/env python3
import os
import tensorflow as tf
import pandas as pd
import numpy as np
from biom import load_table
import click
from skbio.stats.composition import clr, clr_inv
from songbird.multinomial import MultRegression
from songbird.util import read_metadata, match_and_filter, split_training


@click.group()
def songbird():
    # Root click command group; sub-commands attach themselves through
    # ``@songbird.command()``.  A comment is used instead of a docstring so
    # the text printed by ``--help`` stays exactly as it was.
    return None


@songbird.command()
@click.option('--input-biom', show_default=True, required=True,
              help='Input abundances')
@click.option('--metadata-file', show_default=True, required=True,
              help='Sample metadata table with covariates of interest.')
@click.option('--formula', show_default=True, required=True,
              help=('Statistical formula specifying the covariates '
                    'to test for.'))
@click.option('--training-column', default=None, show_default=True,
              help=('The column in the metadata file used to '
                    'specify training and testing. These columns '
                    'should be specifically labeled (Train) and (Test)'))
@click.option('--num-random-test-examples', default=5, show_default=True,
              help=('Number of random testing examples if '
                    '--training-column is not specified'))
@click.option('--epochs', show_default=True, default=1000,
              help=('The total number of iterations '
                    'over the entire dataset'))
@click.option('--batch-size', show_default=True,
              help='Size of mini-batch', default=5)
@click.option('--differential-prior', show_default=True, default=1.,
              help=('Width of normal prior for the `differentials`, or '
                    'the coefficients of the multinomial regression. '
                    'Smaller values will force the coefficients towards zero. '
                    'Values must be greater than 0.'))
@click.option('--learning-rate', show_default=True, default=1e-3,
              help=('Gradient descent decay rate.'))
@click.option('--clipnorm', show_default=True, default=10.,
              help=('Gradient clipping size.'))
@click.option('--min-sample-count', show_default=True, default=1000,
              help=("The minimum number of counts a sample needs "
                    "for it to be included in the analysis"))
@click.option('--min-feature-count', show_default=True, default=5,
              help=("The minimum number of counts a feature needs for "
                    "it to be included in the analysis"))
@click.option('--checkpoint-interval', show_default=True, default=3600,
              help='Number of seconds before saving a checkpoint')
@click.option('--summary-interval', show_default=True, default=60,
              help='Number of seconds before storing a summary.')
@click.option('--summary-dir', default='summarydir', show_default=True,
              help='Summary directory to save regression results. '
              'This will include a table of differentials under '
              '`differentials.tsv` that can be ranked, in addition '
              'to summaries that can be loaded into Tensorboard and '
              'checkpoints for recovering parameters during runtime.')
def multinomial(input_biom, metadata_file, formula, training_column,
                num_random_test_examples, epochs, batch_size,
                differential_prior, learning_rate, clipnorm, min_sample_count,
                min_feature_count, checkpoint_interval, summary_interval,
                summary_dir):
    """Fit a multinomial regression and save the differentials.

    The abundance table and sample metadata are matched and filtered,
    split into training and testing sets, and used to fit the model.
    The resulting differentials are written to `differentials.tsv`
    inside the summary directory.
    """
    # The help text promises a strictly positive prior width; reject bad
    # values up front instead of failing somewhere inside model fitting.
    if differential_prior <= 0:
        raise click.BadParameter(
            'Values must be greater than 0.',
            param_hint='--differential-prior')

    # load metadata and tables
    metadata = read_metadata(metadata_file)
    table = load_table(input_biom)

    # match them
    table, metadata, design = match_and_filter(
        table, metadata, formula,
        min_sample_count, min_feature_count)

    # convert to dense representation
    # NOTE(review): DataFrame.to_dense() no longer exists in pandas >= 1.0;
    # confirm the pandas version this script is pinned against.
    dense_table = table.to_dataframe().to_dense().T

    # split up training and testing
    trainX, testX, trainY, testY = split_training(
        dense_table, metadata, design,
        training_column,
        num_random_test_examples)

    # build the model
    # NOTE(review): the option is documented as the *width* of the prior but
    # is forwarded as ``beta_mean`` -- confirm against MultRegression whether
    # a scale argument was intended.
    model = MultRegression(learning_rate=learning_rate,
                           clipnorm=clipnorm,
                           beta_mean=differential_prior,
                           batch_size=batch_size,
                           save_path=summary_dir)
    with tf.Graph().as_default(), tf.Session() as session:
        model(session, trainX, trainY, testX, testY)

        model.fit(
            epochs=epochs,
            summary_interval=summary_interval,
            checkpoint_interval=checkpoint_interval)

    md_ids = np.array(design.columns)
    obs_ids = table.ids(axis='observation')

    # Prepend a column of zeros to the fitted coefficients, then round-trip
    # through clr_inv / clr so that each row of coefficients is centred.
    beta_ = clr(clr_inv(np.hstack((np.zeros((model.p, 1)), model.B))))

    # Make sure the output directory exists before writing the table; the
    # model is not guaranteed (from what is visible here) to have created it.
    os.makedirs(summary_dir, exist_ok=True)

    # rows are features (observations), columns are design-matrix covariates
    pd.DataFrame(
        beta_.T, columns=md_ids, index=obs_ids,
    ).to_csv(os.path.join(summary_dir, 'differentials.tsv'), sep='\t')


# Invoke the click command group when this file is executed as a script.
if __name__ == '__main__':
    songbird()
