#!python

"""
cdn Command Line Interface
author: Kevin Menden, DZNE Tübingen

This is the main file for executing the cdn program.
"""

# imports
import click
import scaden
from scaden.scaden_main import training, prediction, processing

@click.group()
@click.version_option('1.0.0')
def cli():
    """cdn command line interface.

    Train a model, predict with a trained model, or process a dataset
    for training. Run a subcommand with --help for its options.
    """
    # Intentionally empty: this function only anchors the command group that
    # the subcommands attach to via `@cli.command()`. The docstring above is
    # what click prints as the description in the top-level `--help` output.


"""
Training mode
"""
@cli.command()
@click.argument(
    "data_path",
    type=click.Path(exists=True),
    required=True,
    metavar="<training data>",
)
@click.option(
    "--train_datasets",
    default="",
    help="Datasets used for training. Uses all by default.",
)
@click.option(
    "--model_dir",
    default="./",
    help="Path to store the model in",
)
@click.option(
    "--batch_size",
    default=128,
    help="Batch size to use for training. Default: 128",
)
@click.option(
    "--learning_rate",
    default=0.0001,
    help="Learning rate used for training. Default: 0.0001",
)
@click.option(
    "--steps",
    default=20000,
    help="Number of training steps",
)
def train(data_path, train_datasets, model_dir, batch_size, learning_rate, steps):
    """ Train a cdn model """
    # The docstring is kept verbatim because click uses it as the help text
    # of the `train` subcommand.
    #
    # Every parsed CLI value is handed to `training` unchanged; the only
    # renaming is the `--steps` option, which is passed as `num_steps`.
    training_kwargs = {
        "data_path": data_path,
        "train_datasets": train_datasets,
        "model_dir": model_dir,
        "batch_size": batch_size,
        "learning_rate": learning_rate,
        "num_steps": steps,
    }
    training(**training_kwargs)


"""
Prediction mode
"""
@cli.command()
@click.argument(
    "data_path",
    type=click.Path(exists=True),
    required=True,
    metavar="<prediction data>",
)
@click.option(
    "--model_dir",
    default="./",
    help="Path to trained model",
)
@click.option(
    "--outname",
    default="cdn_predictions.txt",
    help="Name of predictions file.",
)
def predict(data_path, model_dir, outname):
    """ cdn prediction using a trained model """
    # The docstring is kept verbatim because click uses it as the help text
    # of the `predict` subcommand.
    #
    # Forward the parsed values to `prediction`; the `--outname` option is
    # passed under the keyword `out_name`.
    prediction_kwargs = {
        "model_dir": model_dir,
        "data_path": data_path,
        "out_name": outname,
    }
    prediction(**prediction_kwargs)



"""
Processing mode
"""
@cli.command()
@click.argument(
    "data_path",
    type=click.Path(exists=True),
    required=True,
    metavar="<training dataset to be processed>",
)
@click.argument(
    "prediction_data",
    type=click.Path(exists=True),
    required=True,
    metavar="<data for prediction>",
)
@click.option(
    "--processed_path",
    default="processed.h5ad",
    help="Path of processed file. Must end with .h5ad",
)
def process(data_path, prediction_data, processed_path):
    """ Process a dataset for training """
    # The docstring is kept verbatim because click uses it as the help text
    # of the `process` subcommand.
    #
    # NOTE(review): the names cross over here. The first CLI argument
    # (`data_path`, the training dataset) is passed as `training_data`, while
    # the second (`prediction_data`) is passed as `data_path`. This mirrors
    # the original call exactly; confirm against `processing`'s signature
    # that the mapping is intentional.
    processing_kwargs = {
        "data_path": prediction_data,
        "training_data": data_path,
        "processed_path": processed_path,
    }
    processing(**processing_kwargs)


# Run the click command group only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    cli()
