#!/usr/bin/env python
"""
Script used to slurp (seed, eat or monitor) logs. You can also use it as a
daemon monitor.

See --help for usage.
"""
from __future__ import with_statement
import errno
import fcntl
import functools
import logging
import logging.handlers
from optparse import OptionParser
import os
import signal
import select
import sys
import tempfile
import threading
import time

from daemon import DaemonContext
from lockfile.pidlockfile import PIDLockFile
try:
    from setproctitle import setproctitle
except ImportError:
    setproctitle = None

import slurp


# Logger named 'slurp', used by the monitor classes and main() in this script.
# The configure_* helpers below bind a local of the same name to the root
# logger instead.
logger = logging.getLogger('slurp')


def configure_log(level):
    """
    Sets the threshold of the root logger to `level`.
    """
    logging.getLogger().setLevel(level)


def configure_stderr():
    """
    Attaches a handler writing formatted records to stderr to the root logger.
    """
    stderr_handler = logging.StreamHandler(sys.stderr)
    stderr_handler.setFormatter(logging.Formatter(
        '%(process)d : %(asctime)s : %(levelname)s : %(name)s : %(message)s'))
    logging.getLogger().addHandler(stderr_handler)


def configure_syslog():
    """
    Attaches a syslog handler (via the /dev/log socket) to the root logger.
    """
    syslog_handler = logging.handlers.SysLogHandler('/dev/log')
    syslog_handler.setFormatter(
        logging.Formatter('slurp[%(process)d]: %(message)s'))
    logging.getLogger().addHandler(syslog_handler)


class Monitor(object):
    """
    Used to manage monitor.

    Runs `slurp.monitor` in the current process. The configuration is created
    lazily through `create_conf` and discarded on reload so that the next
    access re-creates it.
    """

    class Reload(Exception):
        """
        Raised to reload the configuration.
        """
        pass


    class Terminate(Exception):
        """
        Raised to gracefully terminate.
        """
        pass

    def __init__(self, create_conf):
        """
        `create_conf` is a zero-argument callable returning a configuration.
        """
        self.create_conf = create_conf
        self._conf = None

    @property
    def conf(self):
        """
        The current configuration, created on first access after a reload.
        """
        if not self._conf:
            logger.info('creating configuration')
            self._conf = self.create_conf()
        return self._conf

    @property
    def signal_map(self):
        """
        Mapping of signal number to the handler that should be installed.
        """
        return {
            signal.SIGHUP: self.handle_hup,
            signal.SIGTERM: self.handle_term,
            }

    def handle_term(self, signum, frame):
        """
        SIGTERM handler, interrupts `run` by raising `Terminate`.
        """
        logger.info('terminating')
        raise self.Terminate()

    def handle_hup(self, signum, frame):
        """
        SIGHUP handler, drops the configuration and interrupts `run` by
        raising `Reload` so that monitoring restarts with a fresh one.
        """
        # NOTE: reload used to be called without its required arguments here,
        # which raised TypeError instead of reloading.
        self.reload(signum, frame)
        raise self.Reload()

    def run(self, paths):
        """
        Monitors `paths` until monitoring returns or `Terminate` is raised,
        restarting whenever `Reload` is raised.
        """
        while True:
            try:
                slurp.monitor(paths, self.conf)
                break
            except Monitor.Reload:
                continue
            except Monitor.Terminate:
                break

    def reload(self, signum=None, frame=None):
        """
        Discards the cached configuration. `signum` and `frame` are accepted
        so this can be used directly as a signal handler; both are ignored.
        """
        logger.info('reloading')
        self._conf = None


# Maps signal number to its short name (without the "SIG" prefix).
SIGNALS = dict(
    (getattr(signal, 'SIG' + short_name), short_name)
    for short_name in (
        'HUP', 'QUIT', 'INT', 'TERM', 'USR1', 'USR2', 'WINCH', 'CHLD',
        )
    )


class MonitorWorker(Monitor):
    """
    Monitor run inside a forked worker, restricted to some consumer groups.
    """

    class MasterCheckThread(threading.Thread):
        """
        Daemon thread that sends SIGTERM to this process once its parent pid
        is no longer the one it was started with.
        """

        def __init__(self, parent_pid):
            threading.Thread.__init__(self)
            self.parent_pid = parent_pid
            self.check_freq = 1.0  # seconds between parent checks
            self.daemon = True

        def run(self):
            # poll until the parent pid differs from the one we were given
            while self.parent_pid == os.getppid():
                time.sleep(self.check_freq)
            logger.warning('parent %s changed to %s, terminating',
                self.parent_pid, os.getppid())
            os.kill(os.getpid(), signal.SIGTERM)

    def __init__(self, create_conf, groups, conf):
        super(MonitorWorker, self).__init__(create_conf)
        self.groups = groups
        # start from the configuration handed over by the caller
        self._conf = conf

    @property
    def conf(self):
        """
        The current configuration; when (re)created it is filtered down to
        this worker's consumer groups.
        """
        if self._conf:
            return self._conf
        logger.info('creating configuration')
        self._conf = self.create_conf()
        self._conf.filter_consumers(self.groups)
        return self._conf

    def run(self, paths):
        """
        Starts the parent check thread and then monitors `paths`.
        """
        parent_check = self.MasterCheckThread(os.getppid())
        parent_check.start()
        return super(MonitorWorker, self).run(paths)


class MonitorMaster(object):
    """
    Used to manage forked monitors.

    One worker process is forked per consumer group reported by the
    configuration. Signals are queued by `on_signal` and dispatched from the
    loop in `run` to the matching `handle_<name>` method (e.g. `handle_hup`).

    See https://github.com/benoitc/gunicorn/blob/master/gunicorn/arbiter.py.
    """

    class Terminate(Exception):
        """
        Raised to gracefully terminate the master loop.
        """
        pass


    class Worker(object):
        """
        Tracks the name, pid and failure count of one forked worker.
        """

        def __init__(self, name, pid=None, fail_count=0):
            self.name = name
            self.pid = pid  # None while the worker is not running
            self.fail_count = fail_count

        def kill(self, sig):
            """
            Sends `sig` to the worker process.

            A worker that no longer exists (ESRCH) is marked as not running;
            any other OSError propagates.
            """
            # NOTE: SIGKILL is not in SIGNALS, fall back to the number
            logger.info('sending signal "%s" to %s worker %s',
                SIGNALS.get(sig, sig), self.name, self.pid)
            try:
                os.kill(self.pid, sig)
            except OSError as e:
                if e.errno != errno.ESRCH:
                    raise
                self.pid = None

    def __init__(self, create_conf):
        """
        `create_conf` is a zero-argument callable returning a configuration.
        """
        self.create_conf = create_conf
        self.conf = None
        self.max_fail_count = 3
        self.max_signal_queue = 5
        self.stop_timeout = 10  # seconds
        self.reload_timeout = 10  # seconds
        self.signals = []  # pending signal numbers, oldest first
        self.workers = {}  # consumer group name -> Worker
        self.pipe_r, self.pipe_w = None, None

    @property
    def signal_map(self):
        """
        Mapping of signal number to the handler that should be installed.
        """
        return {
            signal.SIGHUP: self.on_signal,
            signal.SIGTERM: self.on_signal,
            }

    def on_signal(self, signum, frame):
        """
        Signal handler, queues `signum` for the main loop and wakes it up.
        Signals are dropped once `max_signal_queue` are pending.
        """
        logger.debug('signal num - "%s", frame - %s', signum, frame)
        if len(self.signals) < self.max_signal_queue:
            self.signals.append(signum)
            self.wakeup()
        else:
            logger.warning(
                'signal queue %s > %s dropping signal num - "%s", frame - %s',
                len(self.signals), self.max_signal_queue, signum, frame)

    def handle_term(self):
        """
        Handles a queued SIGTERM by ending the main loop.
        """
        raise self.Terminate()

    def handle_hup(self):
        """
        Handles a queued SIGHUP by reloading.
        """
        self.reload()

    def run(self, paths):
        """
        Creates the configuration and then manages workers monitoring
        `paths` until terminated. Workers are always stopped on the way out.
        """
        # title
        if setproctitle:
            setproctitle('slurp: master')

        # conf
        logger.info('creating configuration')
        self.conf = self.create_conf()

        # pipe for waking up during sleep (e.g. when we get a signal)
        self._open_wakeup_pipe()

        # loop
        try:
            while True:
                self.manage_workers(paths)
                sig = self.signals.pop(0) if self.signals else None
                if not sig:
                    self.sleep()
                    continue
                sig_name = SIGNALS.get(sig, str(sig))
                # e.g. SIGHUP -> handle_hup
                signal_handler = getattr(
                    self, 'handle_' + sig_name.lower(), None)
                if not signal_handler:
                    logger.warning('ignoring unhandled signal "%s"', sig_name)
                    continue
                logger.info('handling signal "%s"', sig_name)
                signal_handler()
        except self.Terminate:
            # graceful exit, mirrors Monitor.run
            logger.info('terminating')
        except Exception as ex:
            logger.exception(ex)
            raise
        finally:
            self.stop()

    def _open_wakeup_pipe(self):
        """
        Creates the close-on-exec, non-blocking pipe used to interrupt sleep.
        """
        self.pipe_r, self.pipe_w = os.pipe()
        for fd in [self.pipe_r, self.pipe_w]:
            flags = fcntl.fcntl(fd, fcntl.F_GETFD)
            fcntl.fcntl(fd, fcntl.F_SETFD, flags | fcntl.FD_CLOEXEC)
            # NOTE: must be non-blocking, sleep drains the pipe until EAGAIN
            flags = fcntl.fcntl(fd, fcntl.F_GETFL)
            fcntl.fcntl(fd, fcntl.F_SETFL, flags | os.O_NONBLOCK)

    def sleep(self):
        """
        Waits up to a second for a wakeup and drains the wakeup pipe.
        """
        try:
            ready = select.select([self.pipe_r], [], [], 1.0)
            if not ready[0]:
                return
            # drain, ends with EAGAIN once the pipe is empty
            while os.read(self.pipe_r, 1):
                pass
        except select.error as e:
            if e.args[0] not in [errno.EAGAIN, errno.EINTR]:
                raise
        except OSError as e:
            if e.errno not in [errno.EAGAIN, errno.EINTR]:
                raise
        except KeyboardInterrupt:
            sys.exit()

    def wakeup(self):
        """
        Interrupts a pending `sleep` by writing to the wakeup pipe.
        """
        try:
            os.write(self.pipe_w, b'.')
        except (IOError, OSError) as e:
            # a full pipe (EAGAIN) already guarantees a wakeup
            if e.errno not in [errno.EAGAIN, errno.EINTR]:
                raise

    def stop(self):
        """
        Terminates all workers, killing those still running after
        `stop_timeout` seconds.
        """
        self._terminate_workers(self.stop_timeout)

    def _terminate_workers(self, timeout):
        """
        Sends SIGTERM to running workers until none are left or `timeout`
        seconds have passed, then sends SIGKILL to the remaining ones.
        """
        count = 0
        expires = time.time() + timeout
        while time.time() < expires:
            count = self.kill_workers(signal.SIGTERM)
            if count == 0:
                break
            time.sleep(0.1)
            self.reap_workers()
        if count != 0:
            self.kill_workers(signal.SIGKILL)

    def manage_workers(self, paths):
        """
        Brings the set of workers in line with the configured consumer
        groups: registers missing ones, signals obsolete ones, reaps exited
        ones and spawns those not running.
        """
        groups = list(self.conf.get_consumer_groups())

        # fill in missing workers
        for group in groups:
            if group not in self.workers:
                self.workers[group] = self.Worker(name=group)

        # kill obsolete workers
        obsolete = set(self.workers).difference(groups)
        for name in obsolete:
            worker = self.workers[name]
            if worker.pid is None:
                self.workers.pop(name)
            elif worker.fail_count > self.max_fail_count:
                worker.kill(signal.SIGKILL)
            else:
                worker.kill(signal.SIGTERM)

        # reap dead workers
        self.reap_workers()

        # spawn workers
        for worker in list(self.workers.values()):
            if (worker.pid is None and
                worker.fail_count < self.max_fail_count):
                self.spawn_worker(worker, paths)

    def spawn_worker(self, worker, paths):
        """
        Forks a process for `worker`. Returns in the parent once the pid is
        recorded; the child monitors `paths` and exits without returning.
        """
        # fork
        pid = os.fork()
        if pid:
            worker.pid = pid
            return

        # NOTE: we are the child from here on. Exiting unwinds through the
        # `finally: self.stop()` in run, so forget our siblings or we would
        # signal them on the way out.
        self.workers = {}
        self.signals = []

        # clear signals
        for sig in SIGNALS:
            signal.signal(sig, signal.SIG_DFL)

        # title
        if setproctitle:
            setproctitle('slurp: worker[{}]'.format(worker.name))

        # loop
        try:
            self.conf.filter_consumers([worker.name])
            monitor = MonitorWorker(self.create_conf, [worker.name], self.conf)
            for sig, handler in monitor.signal_map.items():
                signal.signal(sig, handler)
            monitor.run(paths)
        except Exception as ex:
            logger.exception(ex)
            sys.exit(1)
        sys.exit(0)

    def reap_workers(self):
        """
        Collects exited children without blocking, marking their workers as
        not running and counting non-zero exit codes as failures.
        """
        try:
            while True:
                worker_pid, status = os.waitpid(-1, os.WNOHANG)
                if not worker_pid:
                    break
                exit_code = status >> 8
                worker = None
                for candidate in self.workers.values():
                    if candidate.pid == worker_pid:
                        worker = candidate
                        break
                if worker is None:
                    logger.warning('unknown worker %s', worker_pid)
                    continue
                if exit_code != 0:
                    logger.warning('%s worker %s failed with exit code %s',
                        worker.name, worker.pid, exit_code)
                    worker.fail_count += 1
                    if worker.fail_count >= self.max_fail_count:
                        logger.error(
                            '%s worker fail count %s >= max fail count %s',
                            worker.name, worker.fail_count, self.max_fail_count)
                worker.pid = None
        except OSError as e:
            # no children left is expected, anything else is not
            if e.errno != errno.ECHILD:
                raise

    def kill_workers(self, sig):
        """
        Sends `sig` to every running worker and returns how many were
        running.
        """
        count = 0
        logger.info('sending signal "%s" to all workers', SIGNALS.get(sig, sig))
        for worker in list(self.workers.values()):
            if worker.pid is not None:
                count += 1
                worker.kill(sig)
        return count

    def reload(self):
        """
        Stops all workers, resets their failure counts and re-creates the
        configuration. Workers are spawned again by the main loop.
        """
        logger.info('reloading')
        self._terminate_workers(self.reload_timeout)
        for worker in self.workers.values():
            worker.fail_count = 0
        logger.info('recreating configuration')
        self.conf = self.create_conf()


def main():
    """
    Command-line entry point.

    Parses options, configures logging and then runs one of the seed,
    monitor or eat commands against the given paths. Raises Exception with
    the usage text when the command is missing or unknown.
    """
    opt_parser = OptionParser(
        version=slurp.__version__,
        usage="""
%prog s|seed path-1 .. path-n [options]
%prog m|monitor path-1 .. path-n [options]
%prog e|eat path-1 .. path-n [options]\
""")
    opt_parser.add_option(
        '-s', '--state-path', default=tempfile.gettempdir())
    opt_parser.add_option(
        '-c', '--consumer', dest='consumers', action='append', default=[])
    opt_parser.add_option(
        '-l', '--log-level', choices=['debug', 'info', 'warn', 'error'],
        default='warn')
    opt_parser.add_option(
        '--enable-syslog', action='store_true', default=False)
    opt_parser.add_option(
        '--disable-stderrlog', action='store_true', default=False)
    opt_parser.add_option(
        '-d', '--daemonize', action='store_true', default=False)
    opt_parser.add_option(
        '-f', '--forking', action='store_true', default=False)
    opt_parser.add_option(
        '--disable-locking', action='store_true', default=False)
    opt_parser.add_option(
        '--lock-timeout', type='int', default=None)
    opt_parser.add_option(
        '--disable-tracking', action='store_true', default=False)
    opt_parser.add_option(
        '--pid-file', default=None)
    opt_parser.add_option(
        '--sink', choices=['print', 'null'], default=None)
    opt_parser.add_option(
        '--batch-size', type='int', default=None)
    opt_parser.add_option(
        '--disable-backfill', action='store_true', default=False)

    (opts, args) = opt_parser.parse_args()

    if not args:
        raise Exception(opt_parser.get_usage())
    cmd = args[0]
    if cmd not in ('s', 'seed', 'm', 'monitor', 'e', 'eat'):
        raise Exception(opt_parser.get_usage())
    paths = args[1:]

    log_level = getattr(logging, opts.log_level.upper())
    configure_log(log_level)
    if not opts.disable_stderrlog:
        configure_stderr()

    consumer_paths = [
        os.path.abspath(os.path.expanduser(os.path.expandvars(consumer_path)))
        for consumer_path in opts.consumers]

    sink = None
    if opts.sink == 'null':
        sink = slurp.null_sink
    elif opts.sink == 'print':
        sink = slurp.print_sink

    # NOTE: intentionally deferred
    create_conf = functools.partial(
        slurp.Conf,
        opts.state_path,
        consumer_paths,
        not opts.disable_locking,
        opts.lock_timeout,
        not opts.disable_tracking,
        sink,
        opts.batch_size,
        opts.disable_backfill)

    if cmd in ('s', 'seed'):
        if opts.enable_syslog:  # NOTE: intentionally deferred
            # configure_syslog takes no arguments, the level is already set
            # on the root logger by configure_log
            configure_syslog()
        slurp.seed(paths or sys.stdin, create_conf())
    elif cmd in ('m', 'monitor'):
        monitor = (MonitorMaster if opts.forking else Monitor)(create_conf)
        if opts.daemonize:
            if opts.pid_file:
                pid_file = PIDLockFile(opts.pid_file)
            else:
                pid_file = None
            with DaemonContext(
                    pidfile=pid_file,
                    files_preserve=[],
                    signal_map=monitor.signal_map):
                if opts.enable_syslog:  # NOTE: intentionally deferred
                    configure_syslog()
                try:
                    monitor.run(paths)
                except Exception as ex:
                    logger.exception(ex)
                    raise
        else:
            if opts.enable_syslog:  # NOTE: intentionally deferred
                configure_syslog()
            # install the monitor's handlers so HUP (reload) and TERM
            # (graceful stop) work in the foreground as they do when
            # daemonized
            for sig, handler in monitor.signal_map.items():
                signal.signal(sig, handler)
            monitor.run(paths)
    elif cmd in ('e', 'eat'):
        if opts.enable_syslog:  # NOTE: intentionally deferred
            configure_syslog()
        slurp.eat(paths or sys.stdin, create_conf())


# Run the command-line entry point only when executed as a script.
if __name__ == '__main__':
    main()
