#!/usr/bin/env python
from __future__ import print_function

import os, yaml
import sys
import argparse
import numpy as np
import xbow
from xbow.instances import get_by_name, ConnectedInstance

def provision(args):
    """
    Run a script that configures each node in an xbow cluster.

    Each line in the script should be a shell command. All commands are
    run as root (i.e., 'sudo' is added to the start of each if needed).
    Each command is executed in parallel on each node in the cluster
    (scheduler plus workers). If any command exits with a non-zero exit
    status on any node, execution of the script stops and the error is
    reported.
    Lines in the script file that begin with # are not executed; they are
    echoed to the terminal. Blank lines are skipped.

    Parameters
    ----------
    args : object with a ``script`` attribute giving the path of the
        script file (e.g. the namespace returned by argparse).

    Returns
    -------
    bool
        True if every executed command returned exit status 0 on every
        node (also True when the script contains no commands), False as
        soon as one command fails.

    Raises
    ------
    ValueError
        If no scheduler, or more than one scheduler, is found.
    """
    cfg_file = os.path.join(xbow.XBOW_CONFIGDIR, "settings.yml")

    with open(cfg_file, 'r') as ymlfile:
        # safe_load: calling yaml.load() without a Loader is deprecated and
        # is rejected outright by PyYAML >= 6; the settings file only needs
        # plain YAML types.
        cfg = yaml.safe_load(ymlfile)

    scheduler = get_by_name(cfg['scheduler_name'])
    if len(scheduler) == 0:
        raise ValueError('Error - cannot find the scheduler')
    elif len(scheduler) > 1:
        raise ValueError('Error - more than one scheduler found')
    workers = get_by_name(cfg['worker_pool_name'])
    if len(workers) == 0:
        print('Warning: no workers found')
    # Exactly one scheduler is guaranteed above, so index 0 of all_nodes is
    # the scheduler and index i (i >= 1) is worker i-1.
    all_nodes = scheduler + workers
    all_cis = [ConnectedInstance(i) for i in all_nodes]

    # Defined up front so that an empty script returns True rather than
    # failing on an unbound name.
    status = True
    with open(args.script, 'r') as f:
        for line in f:
            # Remove only the line terminator; slicing off the last
            # character would truncate a final line that has no newline.
            text = line.rstrip('\r\n')
            if line.startswith('#'):
                print(text)
                continue

            command = text.strip()
            if not command:
                # Lines read from a file are never zero-length ("\n" at
                # least), so blank lines have to be detected after stripping.
                continue

            if command.split()[0] != 'sudo':
                command = 'sudo ' + command
            # Flush so the command is visible while the nodes are working.
            print(command + ' : ', end='', flush=True)
            result = exec_all(all_cis, command)
            status = all(code == 0 for code in result)
            if status:
                print('OK')
                continue

            print('FAILED')
            for i, code in enumerate(result):
                if code != 0:
                    if i == 0:
                        print('Error on scheduler:')
                    else:
                        print('Error on worker {}'.format(i-1))
                    print(all_cis[i].output)
            break

    return status

def exec_all(cis, command):
    '''
    Execute a command on a list of connected instances, in parallel.

    The command is started without blocking on every instance except the
    last, then run with the default (blocking) call on the last one, after
    which each of the earlier instances is waited on.

    Parameters
    ----------
    cis : list of connected instances; each must provide
        ``exec_command(command, block=...)``, ``wait()`` and ``exit_status``.
    command : the shell command to run on every instance.

    Returns
    -------
    list
        The ``exit_status`` of each instance, in the same order as ``cis``.
        An empty list is returned when ``cis`` is empty.
    '''
    if not cis:
        # Nothing to run on; indexing cis[-1] below would raise IndexError.
        return []

    background = cis[:-1]
    foreground = cis[-1]

    for instance in background:
        instance.exec_command(command, block=False)
    foreground.exec_command(command)
    for instance in background:
        instance.wait()

    return [instance.exit_status for instance in cis]

if __name__ == '__main__':
    # Command-line entry point: takes the path of the provisioning script.
    parser = argparse.ArgumentParser(description='Provision an Xbow cluster via a script')
    parser.add_argument('script', help='name of script containing shell commands')
    args = parser.parse_args()
    exit_status = provision(args)
    # provision() returns a truthy value on success; turn that into a
    # process exit code so that calling shells and scripts can detect a
    # failed provisioning run instead of always seeing 0.
    sys.exit(0 if exit_status else 1)
