#!/usr/bin/env python
# -*- coding: utf-8 -*-

from __future__ import print_function, division, absolute_import
import os
import sys
import click
import datetime
import six
import shlex
import which
import platform
import time
import contextlib
import StringIO
import traceback
import logging

if os.name == 'posix' and sys.version_info[0] < 3:
    import subprocess32 as subprocess
else:
    import subprocess

# Dry-run switch: when true, Popen() and instant_log_file() only print what
# they would do, the check_* helpers return after a one second sleep, and
# main() exits right after the launch phase.
DUMMY = False


@contextlib.contextmanager
def stdoutIO(stream='stdout'):
    """Temporarily replace ``sys.<stream>`` with an in-memory buffer.

    Args:
        stream: name of the attribute of ``sys`` to redirect, e.g.
            ``'stdout'`` (default) or ``'stderr'``.

    Yields:
        The ``StringIO`` buffer that receives everything written to the
        redirected stream while the ``with`` block is active.

    The original stream object is put back when the block exits, including
    when the block raises.
    """
    old = getattr(sys, stream)
    strio = StringIO.StringIO()
    setattr(sys, stream, strio)
    try:
        yield strio
    finally:
        # Restore unconditionally; without this an exception raised in the
        # with-body would leave the process writing into the buffer forever.
        setattr(sys, stream, old)


def instant_log_file(filename, msg, mode="w"):
    """Write a timestamped message to *filename* and to the debug log.

    Args:
        filename: path of the file to write.
        msg: message text; it is written on the line after the timestamp.
        mode: mode passed to ``open``; the default ``"w"`` replaces any
            previous content of the file.

    When the module-level ``DUMMY`` flag is set, nothing is written or
    logged; file name, timestamp and message are printed instead.
    """
    timestamp = str(datetime.datetime.now())
    if not DUMMY:
        logging.debug(msg)
        with open(filename, mode) as handle:
            handle.write('\n'.join((timestamp, msg)))
        return
    print("{}\n{}\n{}".format(filename, timestamp, msg))


# Slurm placement of this task. Every variable falls back to a single-node,
# single-task default when it is absent from the environment.
node_id = int(os.environ.get('SLURM_NODEID', 0))
local_id = int(os.environ.get('SLURM_LOCALID', 0))
cpus_on_node = int(os.environ.get('SLURM_CPUS_ON_NODE', 1))
nnodes = int(os.environ.get('SLURM_JOB_NUM_NODES', 1))
# Left unconverted: a string when the variable is set, the integer 0 otherwise.
jobid = os.environ.get('SLURM_JOBID', 0)

# One log file per (job, node, local task); it is truncated on every start.
_log_filename = "wos_{}_{:02d}_{:02d}.log".format(jobid, node_id, local_id)
logging.basicConfig(level=logging.DEBUG,
                    format='%(asctime)s %(message)s',
                    filename=_log_filename,
                    filemode='w')

# Start-up banner, sent both to stdout and to the log file.
_banner_values = {
    'date': datetime.datetime.now(),
    'node': platform.node(),
    'execu': sys.executable,
    'node_id': node_id,
    'local_id': local_id,
}
msg = """{date}
node: {node}
executable: {execu}
node_id: {node_id}
local_id: {local_id}
""".format(**_banner_values)
print(msg)
logging.info(msg)


def Popen(*args, **kwargs):
    """Start a subprocess, turning shell command strings into argv lists.

    Accepts the same arguments as ``subprocess.Popen``. When ``shell`` is
    true and the command is given positionally, a command string is split
    with ``shlex.split`` and the process is started with ``shell=False``,
    so that ``terminate()`` reaches the program itself rather than an
    intermediate shell. A command that is already a list or tuple is passed
    through unchanged. A command passed by keyword (``args=...``) is not
    rewritten.

    Returns:
        The ``subprocess.Popen`` object, or ``None`` when the module-level
        ``DUMMY`` flag is set (the call is only printed in that case).
    """
    # .get(): callers that do not pass `shell` at all must not hit KeyError.
    if kwargs.get('shell') and args:
        command = args[0]
        if not isinstance(command, (list, tuple)):
            command = shlex.split(command)
        args = (command, ) + args[1:]
        kwargs['shell'] = False
    if DUMMY:
        print("DUMMY Popen(*{}, **{})".format(args, kwargs))
        return None
    return subprocess.Popen(*args, **kwargs)


def check_ipcontroller(profile_dir, timeout=600, interval=1, silent=False):
    """Block until an ipyparallel ``Client`` can be created for *profile_dir*.

    Args:
        profile_dir: profile directory passed to ``Client(profile_dir=...)``.
        timeout: seconds after which the wait is abandoned.
        interval: seconds to sleep between two connection attempts.
        silent: when true, nothing is printed to stdout; attempts and
            errors are still written to the log.

    Returns:
        None, as soon as a client could be created.

    Raises:
        Exception: with message ``'Timeout'`` when no client could be
            created within *timeout* seconds.
    """
    if DUMMY:
        print('DUMMY check_ipcontroller')
        time.sleep(1)
        return None

    # Imported after the DUMMY check so that a dry run does not need ipyparallel.
    from ipyparallel import Client

    maxtime = time.time() + timeout
    while True:
        logging.debug('Trying Client(profile_dir={}'.format(profile_dir))
        if not silent:
            print('Trying Client(profile_dir={}'.format(profile_dir))
        try:
            client = Client(profile_dir=profile_dir)
        except Exception as e:
            # `silent` only mutes the console; the failure is always logged.
            logging.debug(str(e))
            if not silent:
                print(e)
        else:
            # The client was only a probe; release its connections.
            client.close()
            return None
        if time.time() > maxtime:
            logging.error('time out in check_ipcontroller')
            raise Exception('Timeout')
        # Exactly one sleep per failed attempt, so `interval` means what it says.
        logging.info('sleep({})'.format(interval))
        time.sleep(interval)


def check_dscheduler(dscheduler, timeout=600, interval=1):
    """Block until a ``distributed.Executor`` can be created for *dscheduler*.

    Args:
        dscheduler: scheduler address, passed unchanged to ``Executor``.
        timeout: seconds after which the wait is abandoned.
        interval: seconds to sleep between two connection attempts.

    Returns:
        None, as soon as an executor could be created.

    Raises:
        Exception: with message ``'Timeout'`` when no executor could be
            created within *timeout* seconds.
    """
    if DUMMY:
        print('DUMMY check_dscheduler')
        time.sleep(1)
        return None

    # Imported after the DUMMY check so that a dry run does not need distributed.
    from distributed import Executor

    maxtime = time.time() + timeout
    while True:
        try:
            logging.debug('Trying Executor({})'.format(dscheduler))
            # NOTE(review): the probe executor is not shut down here;
            # confirm whether it should be released explicitly.
            Executor(dscheduler)
        except Exception as e:
            logging.debug(str(e))
        else:
            return None
        if time.time() > maxtime:
            logging.error('time out in check_dscheduler')
            raise Exception('Timeout')
        logging.info('sleep({})'.format(interval))
        time.sleep(interval)


def start_controller(executor='ipyparallel',
                     profile_dir=None,
                     bin_path=None,
                     monitor_period=None,
                     port=None):
    """Launch the controller / scheduler process for one cluster.

    Args:
        executor: ``'ipyparallel'`` (runs ``ipcontroller``) or
            ``'distributed'`` (runs ``dscheduler``).
        profile_dir: profile directory; defaults to the current directory.
        bin_path: directory containing the executable; defaults to
            ``sys.exec_prefix``.
        monitor_period: integer value for ``--HeartMonitor.period``
            (ipyparallel only); the option is omitted when falsy.
        port: integer value for ``--port`` (distributed only); the option
            is omitted when falsy.

    Returns:
        Whatever ``Popen`` returns for the launched command.

    Raises:
        ValueError: if *executor* is not one of the two known types.
    """
    try:
        quote = shlex.quote
    except AttributeError:  # Python 2: shlex has no quote()
        from pipes import quote

    if bin_path is None:
        # NOTE(review): exec_prefix is the installation prefix; confirm that
        # the executables really live directly in it and not in a bin/ subdir.
        bin_path = sys.exec_prefix
    if profile_dir is None:
        profile_dir = os.getcwd()

    if not monitor_period:
        monitor_option = ''
    else:
        monitor_option = " --HeartMonitor.period={:d}".format(monitor_period)

    if not port:
        port_option = ''
    else:
        # The placeholder is named, so the value must be passed by keyword.
        port_option = " --port {port:d}".format(port=port)

    if executor == 'ipyparallel':
        cmd_template = ("{bin_path}/ipcontroller"
                        " {monitor_option}"
                        " --ip=*"
                        " --nodb"
                        " --profile-dir={profile_dir}")
    elif executor == 'distributed':
        # NOTE(review): --profile-dir is handed to dscheduler as well;
        # confirm that the installed dscheduler accepts this option.
        cmd_template = ("{bin_path}/dscheduler"
                        " --no-bokeh"
                        " --no-show"
                        " {port_option}"
                        " --profile-dir={profile_dir}")
    else:
        raise ValueError('Unknow executor type: {}'.format(executor))

    # The command string is split again with shlex in Popen(); quoting keeps
    # paths that contain whitespace or backslashes in one argument. Paths
    # made of ordinary characters are left untouched by quote().
    p_cmd = cmd_template.format(bin_path=quote(bin_path),
                                profile_dir=quote(profile_dir),
                                monitor_option=monitor_option,
                                port_option=port_option)

    print(p_cmd)
    p = Popen(p_cmd, shell=True)

    return p


def start_worker(executor='ipyparallel',
                 profile_dir=None,
                 dscheduler=None,
                 nthreads=1,
                 nprocs=1,
                 bin_path=None):
    """Launch one worker process (``ipengine`` or ``dworker``).

    Args:
        executor: ``'ipyparallel'`` (runs ``ipengine``) or ``'distributed'``
            (runs ``dworker``).
        profile_dir: profile directory for ``ipengine``; defaults to the
            current directory.
        dscheduler: scheduler address for ``dworker``. It is formatted into
            the command as is, so the default ``None`` ends up as the
            literal text ``None``; pass a real address for 'distributed'.
        nthreads: value for ``dworker --nthreads``.
        nprocs: value for ``dworker --nprocs``.
        bin_path: directory containing the executable; defaults to
            ``sys.exec_prefix``.

    Returns:
        Whatever ``Popen`` returns for the launched command.

    Raises:
        ValueError: if *executor* is not one of the two known types.
    """
    try:
        quote = shlex.quote
    except AttributeError:  # Python 2: shlex has no quote()
        from pipes import quote

    if bin_path is None:
        # NOTE(review): exec_prefix is the installation prefix; confirm that
        # the executables really live directly in it and not in a bin/ subdir.
        bin_path = sys.exec_prefix
    if profile_dir is None:
        profile_dir = os.getcwd()

    if executor == 'ipyparallel':
        cmd_template = ("{bin_path}/ipengine --profile-dir={profile_dir}")
    elif executor == 'distributed':
        cmd_template = (
            "{bin_path}/dworker --nthreads {nthreads} --nprocs {nprocs} {dscheduler}")
    else:
        raise ValueError('Unknow executor type: {}'.format(executor))

    # The command string is split again with shlex in Popen(); quoting keeps
    # paths that contain whitespace or backslashes in one argument. Paths
    # made of ordinary characters are left untouched by quote().
    p_cmd = cmd_template.format(profile_dir=quote(profile_dir),
                                dscheduler=dscheduler,
                                nthreads=nthreads,
                                nprocs=nprocs,
                                bin_path=quote(bin_path))

    print(p_cmd)
    p = Popen(p_cmd, shell=True)

    return p


def _terminate_processes(p_engines, p_controllers):
    """Best-effort termination of every launched engine and controller.

    Args:
        p_engines: process handles as returned by ``start_worker``.
        p_controllers: dicts whose ``'proc'`` entry is the handle returned
            by ``start_controller``.

    A failure is logged and otherwise ignored, so that one bad handle (for
    example ``None`` from a DUMMY run) does not stop the remaining cleanup.
    """
    for p in p_engines:
        try:
            p.terminate()
        except Exception:
            logging.warning('could not kill an engine', exc_info=True)
    for p_controller in p_controllers:
        try:
            p_controller['proc'].terminate()
        except Exception:
            logging.warning('could not kill a controller', exc_info=True)


@click.command(help="Launch wowp on a slurm cluster")
@click.option('--executor',
              '-e',
              default='ipyparallel',
              type=click.Choice(['ipyparallel', 'distributed']),
              show_default=True,
              help="Executor type")
@click.option(
    '--pcluster',
    '-p',
    default=cpus_on_node,
    type=int,
    show_default=True,
    help="Number of processes per single cluter, default to one cluster per node")
@click.option('--profiles-dir',
              'profiles_dir',
              default=None,
              type=click.Path(exists=False),
              help="Profile base dir, defaults to current directory")
@click.option('--monitor-period',
              'monitor_period',
              default=0,
              type=int,
              help="Monitor period in ms")
@click.option('--job-file',
              'job_file',
              help='Job status file, processes exit when deleted')
@click.argument('script', nargs=1, required=False, type=click.Path(exists=True))
@click.pass_context
def main(ctx, executor, pcluster, profiles_dir, monitor_period, job_file, script):
    """Start the controller or a worker for this Slurm task, then wait.

    The task with ``local_id == 0`` on each controller node starts a
    controller; every other task waits for its controller and starts a
    worker. The first controller task additionally waits for all
    controllers and then either executes *script* (and exits) or prints how
    to reach the clusters. All remaining tasks loop until interrupted or
    until *job_file* disappears, and terminate their child processes on the
    way out.

    Args:
        ctx: click context (unused).
        executor: executor type forwarded to ``start_worker``; controllers
            are started with ``start_controller``'s default executor.
        pcluster: number of processes per cluster.
        profiles_dir: base directory for the per-cluster profile dirs.
        monitor_period: heart monitor period forwarded to the controller.
        job_file: file whose removal tells all tasks to exit.
        script: optional Python script to execute on the first controller.
    """
    # there seems to be a confusion
    # (https://groups.google.com/forum/#!topic/slurm-devel/3tLPgShGM9A)
    all_hostnames = None
    for varname in ('SLURM_JOB_NODELIST', 'SLURM_NODELIST'):
        if os.getenv(varname):
            with subprocess.Popen('scontrol show hostname ${}'.format(varname),
                                  shell=True,
                                  stdout=subprocess.PIPE) as p:
                all_hostnames = [h.strip() for h in p.stdout.readlines()]
                if all_hostnames:
                    break

    if not all_hostnames:
        print('no hostnames')
        all_hostnames = ['localhost']

    print('all_hostnames: {}'.format(all_hostnames))

    if profiles_dir is None:
        profiles_dir = os.getcwd()
    profiles_dir = os.path.abspath(profiles_dir)
    if not os.path.exists(profiles_dir):
        os.makedirs(profiles_dir)
    print('profiles_dir: {}'.format(profiles_dir))

    nprocs = nnodes * cpus_on_node
    print('nprocs: {}'.format(nprocs))

    bin_path = os.path.dirname(which.which('ipcontroller'))
    print('Binary path: {}'.format(bin_path))

    # pcluster counts processes, so the number of nodes one cluster spans is
    # pcluster divided by the processes per node (not by the job-wide total,
    # which would pin the result to 1 for any realistic pcluster).
    nodes_per_cluster = max(pcluster // max(cpus_on_node, 1), 1)
    print('Nodes per cluster: {}'.format(nodes_per_cluster))

    p_controllers = []
    p_engines = []

    # The first node of every group of nodes_per_cluster nodes hosts a controller.
    controller_node_ids = list(range(0, nnodes, nodes_per_cluster))

    msg = str(locals())
    logging.debug(msg)

    def get_profile_dir(n):
        return os.path.join(profiles_dir, '_profile_{}_{}'.format(jobid, n))

    if node_id in controller_node_ids and local_id == 0:
        profile_dir = get_profile_dir(node_id)
        # start ipcontroller
        p_controllers.append({'proc': start_controller(profile_dir=profile_dir,
                                                       bin_path=bin_path,
                                                       monitor_period=monitor_period),
                              'profile_dir': profile_dir})

    else:
        controller_node_id = controller_node_ids[node_id // nodes_per_cluster]
        profile_dir = get_profile_dir(controller_node_id)
        # wait for controller
        check_ipcontroller(profile_dir, interval=5)
        # start engine
        p_engines.append(start_worker(executor=executor,
                                      profile_dir=profile_dir,
                                      bin_path=bin_path))

    print('--- launched ---')
    print(p_controllers)
    print(p_engines)
    print('--- ---')
    # run the script on the first node
    # TODO the number of workers on this node should be minus 1
    if node_id == controller_node_ids[0] and local_id == 0:
        all_profile_dirs = [get_profile_dir(n) for n in controller_node_ids]
        # wait for all controllers
        msg = "wait for controllers before script"
        instant_log_file("wowp_log_wait", msg)
        for profile_dir in all_profile_dirs:
            check_ipcontroller(profile_dir, interval=5)

        # TODO wait for engines

        if script:
            # run the script
            msg = 'running {}'.format(script)
            print(msg)
            instant_log_file("wowp_log_wait", msg)
            os.environ['WOWP_IPY_PROFILE_DIRS'] = os.pathsep.join(all_profile_dirs)
            os.environ['WOWP_EXECUTOR'] = 'ipyparallel'
            with open(script) as script_file:
                exec_args = (compile(script_file.read(), script, 'exec'), )

            # TODO multiprocessing gets stuck in scheduler init, hence the
            # script is executed in this process.
            return_status = ''

            try:
                with stdoutIO() as stdo, stdoutIO('stderr') as stde:
                    try:
                        # Must be called from this frame: without explicit
                        # namespaces the script runs in main()'s globals/locals.
                        six.exec_(*exec_args)
                    except Exception:
                        traceback.print_exc(file=stde)
                        return_status = traceback.format_exc()
                    finally:
                        msg = "stdo:\n{}\nstde:\n{}\n".format(stdo.getvalue(),
                                                              stde.getvalue())
                        instant_log_file("script_std_{:02d}_{:02d}".format(
                            int(node_id), int(local_id)), msg)
            finally:
                # Also reached when the script raises SystemExit or is
                # interrupted, so no engine or controller is left behind.
                _terminate_processes(p_engines, p_controllers)

            if return_status:
                logging.error('script {} failed:\n{}'.format(script, return_status))
            print('\n--done--')
            if job_file:
                logging.info('remove {} and exit'.format(job_file))
                try:
                    os.remove(job_file)
                except OSError as e:
                    # Already gone (or not removable): nothing left to signal.
                    logging.warning('could not remove {}: {}'.format(job_file, e))
            sys.exit(0)
        else:
            # print scheduler info
            msg = (("WOW:-P scheduler available:\n"
                    "FuturesScheduler('ipyparallel', "
                    "executor_kwargs={{'profile_dirs': {profile_dirs}}})").format(
                        profile_dirs=all_profile_dirs))
            print(msg)
            instant_log_file("wowp_log_wait", msg)

    if DUMMY:
        print('dummy -> exit')
        sys.exit(0)

    # wait for sigterm
    # NOTE(review): no SIGTERM handler is installed in this function; only
    # KeyboardInterrupt and job-file removal reach the cleanup below.
    print('wait for sigterm')
    try:
        logging.info('Entering infinite loop, wait for sigterm')
        while True:
            time.sleep(0.1)
            if job_file and not os.path.exists(job_file):
                logging.info('Job file {} not found --> exit'.format(job_file))
                break
    except KeyboardInterrupt:
        print('\nTerminating processes')
    finally:
        _terminate_processes(p_engines, p_controllers)
        print('\n--done--')


# Run the click command only when this file is executed as a script.
if __name__ == '__main__':
    main()
