#!/usr/bin/python

"""

Swabber is a daemon for managing IP bans. Bans are accepted over one
or more 0mq interfaces and expire after a period of time.

I wrote this code in a simpler time, and now I am quite ashamed
of a lot of it. It all needs a rewrite.

"""

__author__ = "nosmo@nosmo.me"

from swabber import BanCleaner
from swabber import BanFetcher

import yaml

import logging
import logging.handlers
import optparse
import os
import signal
import sys
import time
import threading
import traceback

# Names of the supported ban backends. get_config() rejects a configuration
# whose "backend" value is not in this list.
BACKENDS = ["iptables", "hostsfile", "iptables_cmd"]

# Built-in configuration defaults. get_config() overlays the values read from
# the YAML configuration file on top of these.
DEFAULT_CONFIG = {
    # Handed to BanCleaner; a value of 0 means no cleaner is created.
    # NOTE(review): presumably a duration in seconds - confirm in BanCleaner.
    "bantime": 120,
    # Handed to BanFetcher as the list of endpoints to listen on.
    "bindstrings": ["tcp://127.0.0.1:22620"],
    # Seconds the main loop sleeps between cleaner runs.
    "polltime": 60,
    # Handed to both BanFetcher and BanCleaner.
    "interface": "eth+",
    # Must be one of BACKENDS.
    "backend": "iptables",
    # Handed to BanFetcher.
    "whitelist": [],
    # Log file used when running daemonised (i.e. without --verbose).
    "logpath": "/var/log/swabber.log"
}

def daemon_setup():
    '''Detach the current process so that it runs as a daemon.

    Forks twice (calling setsid() in between), changes the working
    directory to "/", resets the umask to 0 and points stdin and stdout
    at /dev/null. stderr is left attached to whatever it was.

    Each intermediate parent exits with status 0; only the final child
    returns from this function. If either fork fails an error is written
    to stderr and the process exits with status 1.

    No pid file is written here; main() does that after this returns.

    '''
    # First fork
    try:
        if os.fork() > 0:
            # exit first parent
            sys.exit(0)
    except OSError as e:
        sys.stderr.write("fork #1 failed: %d (%s)\n" % (e.errno, e.strerror))
        sys.exit(1)

    # Don't hang onto any files accidentally
    os.chdir("/")
    # Decouple from our environment
    os.setsid()
    # NOTE(review): with a umask of 0, files created later by this process
    # (pid file, rotated logs, backend output) get no permission masking -
    # confirm this is intended.
    os.umask(0)

    # Second fork
    try:
        if os.fork() > 0:
            # exit if second parent
            sys.exit(0)
    except OSError as e:
        sys.stderr.write("Second fork failed: %d (%s)\n" % (e.errno, e.strerror))
        sys.exit(1)

    # Redirect standard file descriptors. dup2() copies the descriptor, so
    # the temporary /dev/null handles can be closed straight away.
    sys.stdout.flush()
    sys.stderr.flush()
    with open("/dev/null", "r") as devnull_in:
        os.dup2(devnull_in.fileno(), sys.stdin.fileno())
    with open("/dev/null", "a+") as devnull_out:
        os.dup2(devnull_out.fileno(), sys.stdout.fileno())
    # stderr is deliberately not redirected.

def get_config(configpath):
    '''Load the configuration file and layer it over the defaults.

    Args:
     configpath: a string containing the path to the YAML config file

    Returns the configuration as a new dictionary; DEFAULT_CONFIG itself
    is never modified, so every call starts from the pristine defaults.

    Raises ValueError if the file does not contain a YAML mapping or if
    the configured backend is not one of BACKENDS. Errors from opening
    or parsing the file propagate unchanged.

    '''

    # Shallow copy: update() below replaces values rather than mutating
    # them, so the defaults stay intact across repeated loads (SIGHUP).
    config = dict(DEFAULT_CONFIG)

    with open(configpath) as config_h:
        loaded = yaml.safe_load(config_h)

    # An empty (or comment-only) file parses to None: use the defaults.
    if loaded is None:
        loaded = {}
    if not isinstance(loaded, dict):
        raise ValueError("Config file %s does not contain a YAML mapping"
                         % configpath)
    config.update(loaded)

    if config["backend"] not in BACKENDS:
        raise ValueError("%s is not in backends: %s"
                         % (config["backend"], ", ".join(BACKENDS)))
    return config

class Swabberd(object):
    '''Owns the ban fetcher and the ban cleaner and runs the main loop.

    Constructing an instance immediately calls run_threads(), which
    blocks until the daemon is told to stop with SIGTERM.

    '''

    def __init__(self, config, configpath):
        '''Store the configuration and start running.

        Args:
         config: a dictionary of the loaded YAML config
         configpath: a string containing the path to the config, re-read
                     on SIGHUP

        '''
        self.config = config
        self.configpath = configpath
        self.banner = None
        self.cleaner = None

        self.run_threads()

    def _make_fetcher(self):
        '''Build a BanFetcher from the current config (does not start it).'''
        return BanFetcher(self.config["bindstrings"],
                          self.config["interface"], self.config["backend"],
                          self.config["whitelist"], self._iptables_lock)

    def _make_cleaner(self):
        '''Build a BanCleaner from the current config, or None if bantime is 0.'''
        if self.config["bantime"] == 0:
            return None
        return BanCleaner(self.config["bantime"], self.config["backend"],
                          self._iptables_lock, self.config["interface"])

    def _start_banner(self):
        '''Start the current fetcher, logging (not raising) any failure.'''
        try:
            self.banner.start()
            logging.warning("Started running banner")
        except Exception as e:
            # NOTE(review): the message says "exiting" but the main loop
            # carries on afterwards - confirm which is intended.
            print("Unhandled exception %s" % e)
            logging.error("Swabber exiting on unhandled exception %s!", str(e))
            self.banner.stop_running()

    def _reload(self):
        '''Re-read the config file and rebuild fetcher and cleaner from it.'''
        try:
            new_config = get_config(self.configpath)
        except Exception as exc:
            # Keep running with the previous config rather than letting the
            # exception escape from the signal handler.
            logging.error("Could not reload config from %s: %s",
                          self.configpath, str(exc))
            return

        self.config = new_config
        self.banner.stop_running()
        self.banner = self._make_fetcher()
        # Rebuilt unconditionally so a bantime that changed to or from 0 is
        # honoured by the main loop.
        self.cleaner = self._make_cleaner()
        # NOTE(review): assumes the old fetcher releases its endpoints once
        # stop_running() has been called - confirm in BanFetcher.
        self._start_banner()
        logging.info("Reloaded config")

    def _handle_signal(self, signum, frame):
        '''Signal handler: SIGTERM stops the daemon, SIGHUP reloads config.'''
        if signum == signal.SIGTERM:
            self.banner.stop_running()
            # Always end the main loop, including when there is no cleaner.
            self.running = False
            logging.warning("Closing on SIGTERM")
        elif signum == signal.SIGHUP:
            self._reload()

    def run_threads(self):

        '''Start the fetcher and run the cleaner loop until stopped.

        Installs handlers for SIGTERM (shut down) and SIGHUP (reload the
        config file), starts the fetcher, then loops in the calling
        thread running the cleaner every "polltime" seconds. When
        "bantime" is 0 there is no cleaner and the loop only idles.

        Takes no arguments; everything is read from self.config.

        Returns nothing.

        '''

        # Shared between the fetcher and the cleaner
        self._iptables_lock = threading.Lock()

        # To control execution of the main loop
        self.running = True

        self.cleaner = self._make_cleaner()
        self.banner = self._make_fetcher()

        signal.signal(signal.SIGTERM, self._handle_signal)
        signal.signal(signal.SIGHUP, self._handle_signal)

        self._start_banner()

        while self.running:
            if self.cleaner is None:
                # Nothing to clean; just idle until a signal stops us.
                time.sleep(0.1)
                continue

            try:
                self.cleaner.clean_bans(self.config["interface"])
            except Exception as exc:
                logging.error("Uncaught exception in cleaner! %s", str(exc))
                traceback.print_exc()
            # Sleep even after a failure so a persistently failing cleaner
            # cannot spin in a tight loop.
            time.sleep(self.config["polltime"])

def main():
    '''Command-line entry point for the swabber daemon.

    Parses the options, checks that we are root (unless forced) and that
    the config file exists, loads the config, sets up logging, daemonises
    unless --verbose was given, and then hands over to Swabberd, which
    blocks until the daemon is stopped.

    Exits with status 1 if not running as root without --force, or if the
    config file does not exist.

    '''

    parser = optparse.OptionParser()
    parser.add_option("-v", "--verbose", dest="verbose",
                      help="Be verbose in output, don't daemonise",
                      action="store_true")
    parser.add_option("-F", "--force", dest="forcerun",
                      help="Try to run when not root",
                      action="store_true")

    parser.add_option("-c", "--config",
                      action="store", dest="configpath",
                      default="/etc/swabber.yaml",
                      help="alternate path for configuration file")

    options, _args = parser.parse_args()

    # Run the sanity checks before touching the config or the log file so
    # the user gets these messages rather than a traceback.
    if os.getuid() != 0 and not options.forcerun:
        sys.stderr.write("Not running as I need root access - use -F to force run\n")
        sys.exit(1)

    if not os.path.isfile(options.configpath):
        sys.stderr.write("Couldn't load config file %s!\n" % options.configpath)
        sys.exit(1)

    config = get_config(options.configpath)

    # Set up logging: stdout when verbose, the configured log file otherwise.
    root_logger = logging.getLogger()
    if options.verbose:
        # A single handler only; also calling basicConfig() here would
        # install a second one and print every record twice.
        root_logger.setLevel(logging.DEBUG)
        log_handler = logging.StreamHandler(sys.stdout)
        log_handler.setLevel(logging.DEBUG)
        log_handler.setFormatter(logging.Formatter(
            'swabber (%(process)d): %(levelname)s %(message)s'))
    else:
        root_logger.setLevel(logging.INFO)
        log_handler = logging.handlers.WatchedFileHandler(config["logpath"])
        log_handler.setFormatter(logging.Formatter(
            'swabber (%(process)d) %(asctime)s: %(levelname)s %(message)s'))
    root_logger.addHandler(log_handler)

    if not options.verbose:

        daemon_setup()

        # Written after daemonising so it records the final process id.
        with open("/var/run/swabberd.pid", "w") as mypid:
            mypid.write(str(os.getpid()))

        logging.info("Starting swabber in daemon mode")

    # Blocks until the daemon is stopped.
    Swabberd(config, options.configpath)


if __name__ == "__main__":
    main()
