#!/bin/env python
#
#
# CLI for starting and running Spark standalone clusters on HPC resources
#
#

from __future__ import print_function
import click
from sparkhpc import sparkjob
import subprocess
import os
import logging

# Home directory of the invoking user; used to build the default value of the
# --spark-home option of the `start` command.
home = os.path.expanduser('~')

# Configure the root logger at INFO level and create the logger that the
# commands in this file use for their status messages.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('sparkhpc')

@click.group()
@click.option('--scheduler', type=click.Choice(['lsf']), default='lsf', help='Which scheduler to use')
@click.pass_context
def cli(ctx, scheduler):
    # Top-level command group. Resolves the job class for the chosen
    # scheduler and stores it under ctx.obj['SJ'] for the subcommands.
    #
    # ctx       -- click context shared with the subcommands
    # scheduler -- scheduler name selected with --scheduler
    #
    # ctx.obj is only a dict when the caller passes obj={} (as the __main__
    # guard does). Through any other entry point it is None, so make sure a
    # dict exists before storing into it.
    ctx.ensure_object(dict)
    ctx.obj['SJ'] = sparkjob._sparkjob_factory(scheduler)


@cli.command()
@click.argument('ncores')
@click.option('--walltime', default="00:30",
              help="Walltime in HH:MM format")
@click.option('--jobname', default='sparkcluster',
              help='Name to use for the job')
@click.option('--template', default=None,
              help='Job template path')
@click.option('--memory', default='2000', envvar='SPARK_WORKER_MEMORY',
              help='Memory for each worker in MB')
@click.option('--spark-home', default=os.path.join(home, 'spark'), envvar='SPARK_HOME',
              help='Location of the Spark distribution')
@click.option('--wait', default=False, is_flag=True,
              help='Wait until the job starts')
@click.pass_context
def start(ctx, ncores, walltime, jobname, template, memory, spark_home, wait):
    """Start the spark cluster as a batch job"""
    # Builds a job object from the scheduler class stored in ctx.obj['SJ'],
    # forwarding all command-line settings unchanged (ncores is passed on as
    # the string click received; no type conversion is applied here).
    job_class = ctx.obj['SJ']
    job_settings = dict(
        ncores=ncores,
        walltime=walltime,
        jobname=jobname,
        template=template,
        memory=memory,
        spark_home=spark_home,
    )
    job = job_class(**job_settings)

    # Without --wait the job is only submitted.
    if not wait:
        job.submit()
        return

    # NOTE(review): this path does not call submit() itself; presumably
    # wait_to_start() takes care of submission -- confirm against sparkjob.
    logger.info(' Waiting for job to start - ctrl-c to stop')
    job.wait_to_start()
    

@cli.command()
@click.pass_context
def info(ctx):
    """Get info about currently running clusters"""
    # Ask the scheduler's job class (stored in ctx.obj['SJ']) for the
    # current clusters.
    clusters = ctx.obj['SJ'].current_clusters()

    # Nothing to show: log a message and finish.
    if len(clusters) == 0:
        logger.info(' No spark clusters running')
        return

    # show_clusters() is invoked on the first entry only.
    clusters[0].show_clusters()

@cli.command()
@click.argument('clusterid')
@click.pass_context
def stop(ctx, clusterid):
    """Kill a currently running cluster ('all' to kill all clusters)"""
    # clusterid is either the literal string 'all' or a zero-based index into
    # the list returned by current_clusters().
    #
    # Raises click.BadParameter if clusterid is neither 'all' nor an integer,
    # and RuntimeError if the index does not refer to an existing cluster.
    SJ = ctx.obj['SJ']
    sjs = SJ.current_clusters()

    if clusterid == 'all':
        if len(sjs) == 0:
            logger.info(' No clusters running')
        for sj in sjs:
            sj.stop()
        return

    # Report non-numeric input as a usage error instead of letting the
    # ValueError from int() escape as a traceback.
    try:
        index = int(clusterid)
    except ValueError:
        raise click.BadParameter(
            "clusterid must be 'all' or an integer index, got '%s'" % clusterid)

    # Reject negative values explicitly: Python would otherwise treat them as
    # indices from the end of the list and stop an unintended cluster.
    if not 0 <= index < len(sjs):
        raise RuntimeError('Cluster %s does not exist' % clusterid)

    sjs[index].stop()


@cli.command()
@click.option('--memory', default='2000', help='Memory for each worker in MB')
@click.pass_context
def launch(ctx, memory):
    """Launch the Spark master and workers within a current job context"""
    # Delegates directly to sparkjob.start_cluster, passing the --memory value
    # through as the string click received. ctx is accepted because of
    # @click.pass_context but is not used here.
    sparkjob.start_cluster(memory)

if __name__ == "__main__":
    # Seed the click context object with an empty dict; cli() stores the
    # scheduler's job class in it under the 'SJ' key for the subcommands.
    cli(obj={})
