#!/usr/bin/env python3
import os
import time
from tqdm import tqdm
import pandas as pd
import numpy as np
from biom import load_table, Table
from biom.util import biom_open
from skbio.stats.composition import clr, centralize, closure
from skbio.stats.composition import clr_inv as softmax
import matplotlib.pyplot as plt
from scipy.stats import entropy, spearmanr
import click
from scipy.sparse import coo_matrix
import tensorflow as tf
from tensorflow.contrib.distributions import Multinomial, Normal
import datetime
from rhapsody.multimodal import Autoencoder, cross_validation
from rhapsody.util import onehot, rank_hits, random_multimodal, split_tables


# Top-level click command group.  The subcommands in this file attach
# themselves to it through the ``@rhapsody.command()`` decorator.
@click.group()
def rhapsody():
    # Intentionally empty: the group itself does no work, it only dispatches
    # to whichever subcommand was named on the command line.
    pass


@rhapsody.command()
@click.option('--microbe-file',
              help='Input microbial abundances')
@click.option('--metabolite-file',
              help='Input metabolite abundances')
@click.option('--metadata-file', default=None,
              help='Input sample metadata file')
@click.option('--training-column',
              help=('Column in the sample metadata specifying which '
                    'samples are for training and testing.'),
              default=None)
@click.option('--num-testing-examples',
              help=('Number of samples to randomly select for testing'),
              default=10)
@click.option('--min-feature-count',
              help=('Minimum number of samples a microbe needs to be observed '
                    'in order to not filter out'),
              default=10)
@click.option('--epochs',
              help='Number of epochs to train', default=10)
@click.option('--batch_size',
              help='Size of mini-batch', default=32)
@click.option('--latent_dim',
              help=('Dimensionality of shared latent space. '
                    'This is analogous to the number of PC axes.'),
              default=3)
@click.option('--input_prior',
              help=('Width of normal prior for input embedding.  '
                    'Smaller values will regularize parameters towards zero. '
                    'Values must be greater than 0.'),
              default=1.)
@click.option('--output_prior',
              help=('Width of normal prior for output embedding.  '
                    'Smaller values will regularize parameters towards zero. '
                    'Values must be greater than 0.'),
              default=1.)
@click.option('--top-k',
              help=('Number of top hits to compare for cross-validation.'),
              default=50)
@click.option('--learning-rate',
              help=('Gradient descent learning rate.'),
              default=1e-1)
@click.option('--beta1',
              help=('Gradient decay rate for first Adam momentum estimates'),
              default=0.9)
@click.option('--beta2',
              help=('Gradient decay rate for second Adam momentum estimates'),
              default=0.95)
@click.option('--clipnorm',
              help=('Gradient clipping size.'),
              default=10.)
@click.option('--threads',
              help=('Number of threads to utilize.'),
              default=64)
@click.option('--checkpoint-interval',
              help=('Number of seconds before storing a checkpoint.'),
              default=1000)
@click.option('--summary-interval',
              help=('Number of seconds before storing a summary.'),
              default=1000)
@click.option('--summary-dir', default='summarydir',
              help='Summary directory to save cross validation results.')
@click.option('--ranks-file',
              help='Ranks file containing microbe-metabolite rankings.')
def autoencoder(microbe_file, metabolite_file,
                metadata_file, training_column,
                num_testing_examples, min_feature_count,
                epochs, batch_size, latent_dim,
                input_prior, output_prior, top_k,
                learning_rate, beta1, beta2, clipnorm, threads,
                checkpoint_interval, summary_interval,
                summary_dir, ranks_file):
    """Train the microbe-metabolite autoencoder and save its results.

    Loads the microbe and metabolite tables, splits the samples into
    training and testing sets, fits the model, and writes the fitted
    parameters (U, V and their biases), the cross-validation results and
    the microbe-metabolite ranks to disk.
    """
    # NOTE(review): `batch_size` is accepted on the command line but is never
    # forwarded to the model below -- confirm whether Autoencoder should
    # receive it.
    microbes = load_table(microbe_file)
    metabolites = load_table(metabolite_file)

    # Sample metadata is optional; split_tables receives None without it.
    metadata = None
    if metadata_file is not None:
        metadata = pd.read_table(metadata_file, index_col=0)

    (train_microbes_df, test_microbes_df,
     train_metabolites_df, test_metabolites_df) = split_tables(
        microbes, metabolites,
        metadata=metadata, training_column=training_column,
        num_test=num_testing_examples,
        min_samples=min_feature_count)

    # All result files are written under summary_dir.  Create it up front so
    # a missing directory cannot make the run fail after training is done.
    # An empty string means "current directory" and needs no creation.
    if summary_dir:
        os.makedirs(summary_dir, exist_ok=True)

    # The model's save path is named after the hyperparameters of this run.
    run_name = (
        'latent_dim_%s_input_prior_%.2f_output_prior_%.2f'
        '_beta1_%.2f_beta2_%.2f'
        % (latent_dim, input_prior, output_prior, beta1, beta2))
    save_path = os.path.join(summary_dir, run_name)

    # The microbe tables are handed to the model in sparse COO form.
    train_microbes_coo = coo_matrix(train_microbes_df.values)
    test_microbes_coo = coo_matrix(test_microbes_df.values)

    config = tf.ConfigProto()
    config.intra_op_parallelism_threads = threads
    config.inter_op_parallelism_threads = threads
    with tf.Graph().as_default(), tf.Session(config=config) as session:
        model = Autoencoder(
            latent_dim=latent_dim,
            u_scale=input_prior, v_scale=output_prior,
            learning_rate=learning_rate,
            beta_1=beta1, beta_2=beta2,
            clipnorm=clipnorm, save_path=save_path)
        model(session,
              train_microbes_coo, train_metabolites_df.values,
              test_microbes_coo, test_metabolites_df.values)

        # The value returned by fit() is not needed here; the fitted
        # parameters are read from the model's attributes instead.
        model.fit(epoch=epochs, summary_interval=summary_interval,
                  checkpoint_interval=checkpoint_interval)

        U, V = model.U, model.V

        # Augment U and V with the bias terms and columns/rows of ones so
        # that a single product gives U_ @ V_ == U @ V + Ubias + Vbias.
        U_ = np.hstack(
            (np.ones((U.shape[0], 1)),
             model.Ubias.reshape(-1, 1), U)
        )
        V_ = np.vstack(
            (model.Vbias.reshape(1, -1),
             np.ones((1, V.shape[1])), V)
        )

        # Prepend a zero column, then apply clr_inv (imported as `softmax`)
        # followed by clr to express each row in centered log-ratio
        # coordinates.
        # NOTE(review): assumes V has one column fewer than there are
        # metabolites, so the zero column restores the full width -- confirm.
        logits = np.hstack((np.zeros((U.shape[0], 1)), U_ @ V_))
        ranks = pd.DataFrame(
            clr(softmax(logits)),
            index=train_microbes_df.columns,
            columns=train_metabolites_df.columns)

        np.savetxt(os.path.join(summary_dir, 'U.txt'), model.U)
        np.savetxt(os.path.join(summary_dir, 'V.txt'), model.V)
        np.savetxt(os.path.join(summary_dir, 'Ubias.txt'), model.Ubias)
        np.savetxt(os.path.join(summary_dir, 'Vbias.txt'), model.Vbias)

        cv_params, rank_stats = cross_validation(
            model, test_microbes_df, test_metabolites_df, top_N=top_k)

        cv_params.to_csv(os.path.join(summary_dir, 'model_results.csv'))
        rank_stats.to_csv(os.path.join(summary_dir, 'otu_cv_results.csv'))

        # Without --ranks-file there is nowhere to write the ranks;
        # DataFrame.to_csv(None) would only build a string and discard it.
        if ranks_file is not None:
            ranks.to_csv(ranks_file)


@rhapsody.command()
@click.option('--ranks-file',
              help='Ranks file containing microbe-metabolite rankings')
@click.option('--k-nearest-neighbors',
              help=('Number of nearest neighbors.'),
              default=3)
@click.option('--node-metadata',
              help='Node metadata for cytoscape.')
@click.option('--edge-metadata',
              help='Edge metadata for cytoscape.')
@click.option('--axis', default=0,
              help=('Direction to draw edges. '
                    'axis=0 guarantees that each column will have at least '
                    'k edges. '
                    'axis=1 guarantees that each row will have at least '
                    'k edges.'))
def network(ranks_file, k_nearest_neighbors, node_metadata, edge_metadata, axis):
    """Build Cytoscape node and edge tables from a ranks file.

    Reads the ranks table, computes edges with rank_hits, and writes a
    tab-separated edge file and a tab-separated node file.
    """
    # The ranks file is transposed on load; axis selects which orientation
    # is handed to rank_hits.
    ranks = pd.read_csv(ranks_file, index_col=0).T
    if axis == 0:
        edges = rank_hits(ranks, k_nearest_neighbors)
    else:
        edges = rank_hits(ranks.T, k_nearest_neighbors)
    edges['edge_type'] = 'co_occur'

    # Distinct endpoints; value_counts orders them by descending frequency.
    src = list(edges['src'].value_counts().index)
    dest = list(edges['dest'].value_counts().index)

    # Edge file: one "src <TAB> edge_type <TAB> dest" row per edge, written
    # without a header line.
    edges = edges.set_index(['src'])
    edges[['edge_type', 'dest']].to_csv(edge_metadata, sep='\t', header=False)

    # Node file: every source node followed by every destination node, each
    # labelled with the role it plays.
    nodes = pd.DataFrame(
        {'id': src + dest,
         'node_type': ['src'] * len(src) + ['dest'] * len(dest)},
        columns=['id', 'node_type'])
    nodes = nodes.set_index('id')
    nodes.to_csv(node_metadata, sep='\t')


# Run the command-line interface only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == '__main__':
    rhapsody()
