#!/usr/bin/env python

import os
import sys
import time

import multiprocessing as mp
import threading       as mt

import radical.utils   as ru

# needed for `queue.Empty`, which `mp.Queue.get()` raises on timeout
import queue

# FIXME: the func executor may need a small bootstrapper

# first argument: the directory used below for the config file and as `path`
# for the logger and profiler
pwd = sys.argv[1]


# ------------------------------------------------------------------------------
# activate virtenv if needed
#
# optional second argument: location of the virtenv to activate
ve = None
if len(sys.argv) > 2:
    ve = sys.argv[2]

# the argument may be absent, empty, or the literal string 'None' - none of
# those names a virtenv
if ve and ve != 'None':

    activate = "%s/bin/activate_this.py" % ve

    # py3 replacement for `execfile(activate, dict(__file__=activate))`: run
    # the activation script in this interpreter.  The `with` block ensures the
    # file handle is closed again.
    with open(activate) as fin:
        exec(fin.read(), dict(__file__=activate))


# ------------------------------------------------------------------------------
#
class Executor(object):
    '''
    This executor is running as an RP task and owns a complete node.  On each
    core of that node, it spawns a worker process to execute function calls.
    Communication to those processes is established via two mp.Queue instances,
    one for feeding call requests to the worker processes, and one to collect
    results from their execution.

    Once the workers are prepared, the Executor will listen on a task level
    ZMQ channel for incoming call requests, which are then proxied to the
    workers as described above.  This happens in a separate thread.  Another
    thread is spawned to inversely collect the results as described above and to
    proxy them to an outgoing ZMQ channel.  The Executor main thread will listen
    on a 3rd ZMQ channel for control messages, and specifically for termination
    commands.

    NOTE: the control channel is currently disabled in the code below, so the
          main thread only idles in `run()`.
    '''

    # --------------------------------------------------------------------------
    #
    def __init__(self, n_workers=None):
        '''
        n_workers: number of worker processes to spawn.  If not set (or 0),
                   one worker per core reported by `mp.cpu_count()` is used.

        The executor ID is read from the environment variable `RP_FUNCS_ID`
        (a `KeyError` is raised if it is unset), and the configuration is
        read from `<pwd>/<uid>.cfg`.
        '''

        self._nw   = n_workers
        self._uid  = os.environ['RP_FUNCS_ID']
        self._log  = ru.Logger(self._uid,   ns='radical.pilot', path=pwd)
        self._prof = ru.Profiler(self._uid, ns='radical.pilot', path=pwd)
        self._cfg  = ru.read_json('%s/%s.cfg' % (pwd, self._uid))

        self._initialize()


    # --------------------------------------------------------------------------
    #
    def _initialize(self):
        '''
        set up processes, threads and communication channels

        Raises a `ValueError` if the configuration lacks the `req_get` or
        `res_put` channel address.
        '''

        self._prof.prof('init_start', uid=self._uid)

        addr_req = self._cfg.get('req_get')
        addr_res = self._cfg.get('res_put')

        self._log.debug('req get addr: %s', addr_req)
        self._log.debug('res put addr: %s', addr_res)

        # explicit checks: `assert` statements vanish when running with `-O`
        if not addr_req:
            raise ValueError('missing config entry: req_get')

        if not addr_res:
            raise ValueError('missing config entry: res_put')

        # connect to
        #
        #   - the queue which feeds us tasks
        #   - the queue where we send completed tasks
        #   - the command queue (for termination) - currently disabled
        #
        self._zmq_req = ru.zmq.Getter(channel='funcs_req_queue', url=addr_req)
        self._zmq_res = ru.zmq.Putter(channel='funcs_res_queue', url=addr_res)
      # self._zmq_ctl = ru.zmq.Getter(channel='CTL', url=addr['CTL_GET'])

        # use mp.Queue instances to proxy tasks to the worker processes
        self._mpq_work    = mp.Queue()
        self._mpq_result  = mp.Queue()

        # signal for thread termination
        self._term = mt.Event()

        # start one worker per core
        if not self._nw:
            self._nw = mp.cpu_count()

        self._log.debug('#workers: %d', self._nw)

        # The workers are spawned *before* the feeder / drain threads are
        # started: forking while those threads are running could copy locks
        # (queue, logger) in a held state into the child processes.
        self._workers = list()
        for i in range(self._nw):
            wid  = '%s.%03d' % (self._uid, i)
            proc = mp.Process(target=self._work, args=[self._uid, wid])
            proc.daemon = True
            proc.start()
            self._workers.append(proc)

        # start threads to feed / drain the workers
        self._t_get_work    = mt.Thread(target=self._get_work)
        self._t_get_results = mt.Thread(target=self._get_results)

        self._t_get_work.daemon    = True
        self._t_get_results.daemon = True

        self._t_get_work.start()
        self._t_get_results.start()

        self._prof.prof('init_stop', uid=self._uid)


    # --------------------------------------------------------------------------
    #
    def run(self):
        '''
        executor main loop: listen on the command channel for things to do
        (like, terminate).  On a `term` command, the helper threads are
        signalled, the worker processes are terminated, and the executor exits
        via `sys.exit(0)`.

        NOTE: the command channel is disabled at the moment - `msgs` is always
              `None`, so this loop only sleeps and never returns.
        '''

        while True:

            # msgs = self._zmq_ctl.get_nowait(100)
            msgs = None
            time.sleep(1)

            if not msgs:
                continue

            for msg in msgs:

                self._prof.prof('cmd', uid=self._uid, msg=msg['cmd'])

                if msg['cmd'] == 'term':

                    # signal the feeder / drain threads to stop
                    self._term.set()

                    # kill worker processes
                    for worker in self._workers:
                        worker.terminate()

                    sys.exit(0)

                else:
                    self._log.error('unknown command %s', msg)


    # --------------------------------------------------------------------------
    #
    def _get_work(self):
        '''
        thread feeding tasks pulled from the ZMQ work queue to worker processes
        '''

        # FIXME: This drains the work queue with no regard of load balancing.
        #        For example, the first <n_cores> tasks may stall this executor
        #        for a long time, but new tasks are pulled nonetheless, even if
        #        other executors are not stalling and could execute them timely.
        #        We should at most fill a cache of limited size.

        while not self._term.is_set():

            tasks = self._zmq_req.get_nowait(1000)

            if tasks:

                self._log.debug('got %d tasks', len(tasks))

                # send task individually to load balance workers
                for task in tasks:
                    self._mpq_work.put(task)


    # --------------------------------------------------------------------------
    #
    def _get_results(self):
        '''
        thread feeding back results from the workers to the result ZMQ queue
        '''

        while not self._term.is_set():

            # we always pull *individual* tasks from the result queue
            try:
                task = self._mpq_result.get(block=True, timeout=0.1)

            except queue.Empty:
                continue

            if task:
                self._zmq_res.put(task)


    # --------------------------------------------------------------------------
    #
    def _work(self, uid, wid):
        '''
        work loop for worker processes: pull a task from the work queue,
        run it, push the result onto the result queue

        uid: ID of the executor (used for profiling)
        wid: ID of this worker; stored as `task['wid']` on each result

        This method never returns - the worker processes are daemonic and get
        terminated by the executor.
        '''

        self._prof.prof('work_start', comp=wid, uid=uid)

        while True:

            try:
                task = self._mpq_work.get(block=True, timeout=0.1)

            except queue.Empty:
                continue

            # don't let a task without `uid` kill this worker
            tid = task.get('uid')

            self._prof.prof('task_get', comp=wid, uid=tid)
            self._execute(task)
            self._prof.prof('task_put', comp=wid, uid=tid)

            task['wid'] = wid
            self._mpq_result.put(task)


    # --------------------------------------------------------------------------
    #
    @staticmethod
    def _execute(task):
        '''
        run a single task and store the outcome in the task dict (in place):

          - success: `stdout` holds the return value of the call, `stderr` is
                     `None`, `state` is `'DONE'`
          - failure: `stdout` is `None`, `stderr` holds the error message,
                     `state` is `'FAILED'`

        Any `Exception` - including those caused by a malformed task
        description - results in a failed task and is not propagated, so that
        the worker process survives.
        '''

        try:
            descr = task['description']
            exe   = descr['executable']
            args  = descr.get('arguments') or list()
            pres  = descr.get('pre_exec')  or list()

            # only `import <mod> [<mod> ...]` directives are interpreted,
            # everything else (including blank entries) is ignored
            for pre in pres:
                words = pre.split()
                if words and words[0] == 'import':
                    for mod in words[1:]:
                        if mod not in globals():
                            globals()[mod] = ru.import_module(mod)

            # arguments are rendered verbatim into the call expression:
            # string arguments need to carry their own quotes
            cmd = '%s(%s)' % (exe, ','.join(str(arg) for arg in args))

            # SECURITY: `eval` runs arbitrary code from the task description.
            #           This is only acceptable as long as the request channel
            #           is trusted.
            # The expression is evaluated in the module globals only, so that
            # local names of this method can't shadow imported modules.
            task['stdout'] = eval(cmd, globals())
            task['stderr'] = None
            task['state']  = 'DONE'

        except Exception as e:
            task['stdout'] = None
            task['stderr'] = str(e)
            task['state']  = 'FAILED'


# ------------------------------------------------------------------------------
#
if __name__ == '__main__':

    # script entry point: no explicit worker count is given, so the executor
    # falls back to one worker per core.  `run()` then takes over the main
    # thread.
    executor = Executor(n_workers=None)
    executor.run()


# ------------------------------------------------------------------------------

