#!/usr/bin/env python3

import re
import os
import sys
import time
import pprint
import select
import argparse
import threading

import subprocess    as sp

import radical.utils as ru


# ------------------------------------------------------------------------------
#
# This script is a wrapper around an actual executable.  Its purpose is,
# depending on the mode of operation, to either watch the executable's stdout
# for some pattern, or to watch an output file created by the executable for
# that pattern.  If the pattern is found, the matching data is extracted and
# sent to *this* script's stdout.  If the executable returns a non-zero exit
# code, or if a timeout is reached, the script will also return a non-zero exit
# code.  If the executable returns a zero exit code but a pattern is provided
# and not found, the script will also return a non-zero exit code.
#
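# The script is called with the following arguments:
#
#   -c command : command line (executable and arguments) to run via the shell
#   -r registry: registry URL (default: $RP_REGISTRY_ADDRESS)
#   -u uid     : uid of the service task (default: $RP_TASK_ID)
#   -m mode    : data stream to search through, either `stdout` (default),
#                `stderr`, or `file:<name>` for a file name
#   -p pattern : regex pattern to search for
#   -o output  : file name to store the executable's stdout
#                (default: `<uid>.stdout`)
#   -e error   : file name to store the executable's stderr
#                (default: `<uid>.stderr`)
#   -v         : verbose mode, log every line read from the watched stream
#   -t timeout : time in seconds after which the script will terminate if
#                pattern is not found
#   -P ppid    : pid of a parent process to watch
#
# NOTE: the executable and its arguments are passed as a single string via
#       `-c`; the parser defines no positional arguments, so a trailing
#       `-- executable args` form is not accepted.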


class ServiceWrapper(object):
    '''
    Run a service command, optionally watch one of its output streams for a
    startup pattern, and report the outcome on the control pubsub.

    The constructor drives the complete life cycle: it parses the command
    line, connects to the control pubsub, starts the command, watches for the
    pattern, starts the timeout watcher and then waits for the command to
    finish.  It never returns normally: `join_service` always ends in `stop`,
    which calls `sys.exit`.
    '''

    # --------------------------------------------------------------------------
    #
    def __init__(self):
        '''
        Parse the command line and run the service until it terminates.

        Raises:
            ValueError: if no command, no uid or no registry URL is available
                        from either the command line or the environment.
        '''

        parser = argparse.ArgumentParser(description='')

        parser.add_argument('-c', '--command',  dest='command')
        parser.add_argument('-r', '--registry', dest='reg_url')
        parser.add_argument('-u', '--uid',      dest='uid')
        parser.add_argument('-m', '--mode',     dest='mode')
        parser.add_argument('-p', '--pattern',  dest='pattern')
        parser.add_argument('-o', '--output',   dest='output')
        parser.add_argument('-e', '--error',    dest='error')
        parser.add_argument('-v', '--verbose',  dest='verbose', action='store_true')
        parser.add_argument('-t', '--timeout',  dest='timeout')
        parser.add_argument('-P', '--ppid',     dest='ppid')

        args = parser.parse_args()
        self._command = args.command or None
        self._reg_url = args.reg_url or os.environ.get('RP_REGISTRY_ADDRESS')
        self._uid     = args.uid     or os.environ.get('RP_TASK_ID')
        self._mode    = args.mode    or 'stdout'
        self._pattern = args.pattern or None
        self._output  = args.output  or '%s.stdout' % self._uid
        self._error   = args.error   or '%s.stderr' % self._uid
        self._verbose = args.verbose or False
        self._timeout = args.timeout or 0
        self._ppid    = args.ppid    or None
        self._pwatch  = None

        # the timeout arrives as string from the command line; 0 disables it
        self._timeout = int(self._timeout) if self._timeout else 0

        if not self._command: raise ValueError('no command given')
        if not self._uid    : raise ValueError('no uid given')
        if not self._reg_url: raise ValueError('RP_REGISTRY_ADDRESS not set')

        self._sbox = os.environ.get('RP_TASK_SANDBOX', './')
        self._log  = ru.Logger('radical.pilot.service_wrapper',
                               targets='%s/%s.log' % (self._sbox, self._uid))
        self._log.debug('starting service')

        # a process watcher is only created when a parent pid was given
        if self._ppid:
            self._pwatch = ru.PWatcher(uid='%s.pw' % self._uid,
                                       action=ru.PWatcher.KILLALL,
                                       log=self._log)
            self._pwatch.watch(int(self._ppid))

        self._stop_evt      = threading.Event()
        self._timeout_evt   = threading.Event()
        self._match         = None
        self._result_sent   = False

        self._proc = None
        self._ctrl = 'control_pubsub'
        self._rank = int(os.environ.get('RP_RANK', 0))

        self.connect_pubsub()
        self.start_service()
        self.start_to_watcher()
        self.join_service()


    # --------------------------------------------------------------------------
    #
    def connect_pubsub(self):
        '''
        Look up the control pubsub address in the registry and create the
        publisher used by `send_result`.
        '''

        # get control pubsub from registry
        self._reg = ru.zmq.RegistryClient(url=self._reg_url)
        addr_pub  = self._reg.get('bridges.%s' % self._ctrl)['addr_pub']
        self._pub = ru.zmq.Publisher(channel=self._ctrl, url=addr_pub)

        self._log.debug('reg_url: %s', self._reg_url)
        self._log.debug('pub_url: %s', addr_pub)

        # let zmq settle
        time.sleep(0.1)


    # --------------------------------------------------------------------------
    #
    def send_result(self, err_msg=None):
        '''
        Publish a `service_info` message for this service on the control
        pubsub.

        The message carries the service uid, the matched data (`None` if
        nothing was matched) and `err_msg`.  Only the first call publishes;
        later calls are logged as duplicates and ignored.

        Args:
            err_msg: error description, or `None` on success.
        '''

        self._log.info('send result: %s [%s]', self._match, err_msg)

        if self._result_sent:
            self._log.error('duplicated?')
            return

        msg = {'cmd': 'service_info',
               'arg': {'uid'  : self._uid,
                       'error': err_msg,
                       'info' : self._match}}

        self._log.debug(pprint.pformat(msg))
        self._pub.put(self._ctrl, msg)

        self._result_sent = True


    # --------------------------------------------------------------------------
    #
    def stop(self, reason=None, ret=1):
        '''
        kill the service, stop all watchers, notify agent, and exit

        Args:
            reason: error description forwarded to `send_result`.
            ret   : exit code passed to `sys.exit`.

        NOTE: `sys.exit` only terminates the process when called from the
              main thread.  When called from the timeout watcher thread it
              ends only that thread; the main thread then sees the killed
              process in `join_service` and exits from there.
        '''

        self._log.debug('stop: %s', reason)
        if self._proc:
            self._proc.kill()

        self._stop_evt.set()
        self._timeout_evt.set()

        self.send_result(reason)

        sys.exit(ret)


    # --------------------------------------------------------------------------
    #
    def start_to_watcher(self):
        '''
        Start a daemon thread which stops the service once `timeout` seconds
        have passed.  Nothing is started if no timeout is configured.

        NOTE(review): the timer starts when this method is called, i.e., after
                      `start_service` returned, and is not cancelled by
                      a pattern match - confirm that this is intended.
        '''

        if not self._timeout:
            return

        # ----------------------------------------------------------------------
        def watch_to():

            self._log.debug('start to watcher')
            start = time.time()

            while not self._timeout_evt.is_set():

                if time.time() - start > self._timeout:
                    # raises SystemExit, which ends this thread
                    self.stop('timeout')

                time.sleep(0.5)

            self._log.debug('stop to watcher')
        # ----------------------------------------------------------------------

        watcher = threading.Thread(target=watch_to)
        watcher.daemon = True
        watcher.start()


    # --------------------------------------------------------------------------
    #
    def start_service(self):
        '''
        Start the command with stdout and stderr redirected to the configured
        files and, on rank 0 with a pattern given, block until the pattern is
        found in the watched stream.

        Without a pattern an empty result is sent right away.  Any exception
        raised during startup stops the service with a `command failed` error.
        '''

        # most of the code below is only relevant for rank 0 and if a pattern is
        # given.

        # run the command
        cmd = '%s >%s 2>%s' % (self._command, self._output, self._error)

        try:
            self._log.debug('executing: %s', cmd)
            self._proc = sp.Popen(cmd, shell=True)

            # the process watcher only exists if a parent pid was given
            if self._pwatch:
                self._pwatch.watch(self._proc.pid)
            self._log.debug('proc_pid: %s', self._proc.pid)

            if not self._pattern:
                self._log.info('no pattern to watch')
                self.send_result()
                return

            if self._rank > 0:
                self._log.debug('rank %s > 0, no pattern to watch', self._rank)
                return

            # select the stream to watch
            fname = None
            if   self._mode == 'stdout'        : fname = self._output
            elif self._mode == 'stderr'        : fname = self._error
            elif self._mode.startswith('file:'): fname = self._mode[5:]

            if not fname:
                if self._mode:
                    self.stop('unknown mode %s' % self._mode)
                return

            self._watch_pattern(fname)

        except Exception as e:
            # `stop` raises SystemExit which is not caught here
            self.stop('command failed: %s' % repr(e))


    # --------------------------------------------------------------------------
    #
    def _watch_pattern(self, fname):
        '''
        Follow file `fname` while the service runs and search each complete
        line for the configured pattern.  On the first match the matching data
        are stored in `self._match` and the result is sent.

        The method returns without a match if the service terminates.  If a
        timeout is configured and expires first, the service is stopped with
        a `startup timeout` error.
        '''

        start = time.time()

        # the file may not exist yet - wait for it to appear
        fstream = None
        while self._proc.poll() is None:

            if self._timeout and time.time() - start > self._timeout:
                self.stop('startup timeout')

            try:
                # the stream is read as raw bytes via its file descriptor
                fstream = open(fname, 'rb')
                break
            except OSError:
                time.sleep(0.1)

        self._log.debug('fstream [%s]', fstream)
        self._log.debug('pattern [%s]', self._pattern)

        if fstream is None:
            # service terminated before the file could be opened
            return

        regex = re.compile('^.*(%s).*$' % self._pattern)
        data  = b''

        try:
            fd = fstream.fileno()

            while self._proc.poll() is None:

                if self._timeout and time.time() - start > self._timeout:
                    self.stop('startup timeout')

                r, _, _ = select.select([fstream], [], [], 0.1)
                if fstream not in r:
                    continue

                chunk = os.read(fd, 4096)
                if not chunk:
                    # regular files are always reported readable, also at
                    # end-of-file: back off instead of spinning
                    time.sleep(0.1)
                    continue

                # only complete lines are searched, the remainder is kept
                data += chunk
                *lines, data = data.split(b'\n')

                for raw in lines:

                    # decode per line so that multi-byte characters are never
                    # split, and tolerate invalid byte sequences
                    line = raw.decode('utf-8', errors='replace')

                    if self._verbose:
                        self._log.debug('line: %s', line)

                    match = regex.match(line)
                    if match:
                        self._match = match.groups()[-1]
                        self.send_result()
                        return

        finally:
            fstream.close()


    # --------------------------------------------------------------------------
    #
    def join_service(self):
        '''
        now just run until `cmd` completes and return its exit code

        Any non-zero return code counts as failure; that includes negative
        values, which `subprocess` reports when the process was terminated by
        a signal (for example by `stop` from the timeout watcher).
        '''

        while self._proc.poll() is None:
            time.sleep(1.0)

        ret = self._proc.poll()

        if ret != 0:
            self.stop('command failed', ret)

        else:
            self.stop(ret=0)


# ------------------------------------------------------------------------------
#
if __name__ == '__main__':

    # the constructor runs the complete service life cycle (connect, start,
    # watch, join) and terminates the process via `sys.exit`
    ServiceWrapper()


# ------------------------------------------------------------------------------

