#!/usr/bin/python
import atexit, errno, logging, os, signal, socket
import sqlite3, subprocess, sys, time, threading
from collections import deque
from OpenSSL import crypto
from re6st import ctl, db, plib, tunnel, utils, version
from re6st.registry import RegistryClient, RENEW_PERIOD
from re6st.utils import exit

class ReexecException(Exception):
    """Ask main() to restart re6stnet by re-executing itself.

    main() catches it, logs its message, runs the exit handlers and then
    calls os.execvp(sys.argv[0], sys.argv).
    """

def getConfig():
    """Parse the re6stnet command line and return the resulting namespace.

    fromfile_prefix_chars='@' lets arguments be read from '@FILE' files
    (assuming utils.ArgParser keeps argparse.ArgumentParser semantics).
    Options are split into general ones and the 'routing' and
    'tunnelling' argument groups.
    """
    parser = utils.ArgParser(fromfile_prefix_chars='@',
        description="Resilient virtual private network application.")
    _ = parser.add_argument
    _('-V', '--version', action='version', version=version.version)

    _('--ip', action='append', default=[],
        help="IP address advertised to other nodes. Special values:\n"
             "- upnp: redirect ports when UPnP device is found\n"
             "- any: ask peers our IP\n"
             " (default: like 'upnp' if miniupnpc is installed,\n"
             "  otherwise like 'any')")
    _('--registry', metavar='URL',
        help="Public HTTP URL of the registry, for bootstrapping.")
    _('-l', '--log', default='/var/log/re6stnet',
        help="Path to the directory used for log files:\n"
             "- re6stnet.log: log file of re6stnet itself\n"
             "- babeld.log: log file of router\n"
             "- <iface>.log: 1 file per spawned OpenVPN\n")
    _('-s', '--state', default='/var/lib/re6stnet',
        help="Path to re6stnet state directory:\n"
             "- peers.db: cache of peer addresses\n"
             "- babeld.state: see option -S of babeld\n")
    _('-v', '--verbose', default=1, type=int, metavar='LEVEL',
        help="Log level of re6stnet itself. 0 disables logging."
             " Use SIGUSR1 to reopen log."
             " See also --babel-verb and --verb for logs of spawned processes.")
    _('-i', '--interface', action='append', dest='iface_list', default=[],
        help="Extra interface for LAN discovery. Highly recommended if there"
             " are other re6st nodes on the same network segment.")
    _('-I', '--main-interface', metavar='IFACE', default='lo',
        help="Set re6stnet IP on given interface. Any interface not used for"
             " tunnelling can be chosen.")
    _('--up', metavar='CMD',
        help="Shell command to run after successful initialization.")
    _('--daemon', action='append', metavar='CMD',
        help="Same as --up, but run in background: the command will be killed"
             " at exit (with a TERM signal, followed by KILL 5 seconds later"
             " if process is still alive).")
    _('--test', metavar='EXPR',
        help="Exit after configuration parsing. Status code is the"
             " result of the given Python expression. For example:\n"
             "  main_interface != 'eth0'")

    _ = parser.add_argument_group('routing').add_argument
    _('-B', dest='babel_args', metavar='ARG', action='append', default=[],
        help="Extra arguments to forward to Babel.")
    _('--babel-pidfile', metavar='PID', default='/var/run/re6st-babeld.pid',
        help="Specify a file to write our process id to"
             " (option -I of Babel).")
    _('--control-socket', metavar='CTL_SOCK', default=ctl.SOCK_PATH,
        help="Socket path to use for communication between re6stnet and babeld"
             " (option -R of Babel).")
    _('--hello', type=int, default=15,
        help="Hello interval in seconds, for both wired and wireless"
             " connections. OpenVPN ping-exit option is set to 4 times the"
             " hello interval. It takes between 3 and 4 times the"
             " hello interval for Babel to re-establish connection with a"
             " node for which the direct connection has been cut.")
    _('--table', type=int, default=42,
        help="Use given table id. Set 0 to use the main table, if you want to"
             " access internet via this network (in this case, make sure you"
             " don't already have a default route). Don't use this option with"
             " --gateway (main table is automatically used).")
    _('--gateway', action='store_true',
        help="Act as a gateway for this network (the default route will be"
             " exported). Never use it unless you know what it means.")

    _ = parser.add_argument_group('tunnelling').add_argument
    _('-O', dest='openvpn_args', metavar='ARG', action='append', default=[],
        help="Extra arguments to forward to both server and client OpenVPN"
             " subprocesses. Often used to configure verbosity.")
    _('--ovpnlog', action='store_true',
        help="Tell each OpenVPN subprocess to log to a dedicated file.")
    _('--encrypt', action='store_true',
        help='Specify that tunnels should be encrypted.')
    _('--pp', nargs=2, action='append', metavar=('PORT', 'PROTO'),
        help="Port and protocol to be announced to other peers, ordered by"
             " preference. For each protocol (udp, tcp, udp6, tcp6), start one"
             " openvpn server on the first given port."
             " (default: --pp 1194 udp --pp 1194 tcp)")
    _('--dh',
        help='File containing Diffie-Hellman parameters in .pem format')
    _('--ca', required=True, help=parser._ca_help)
    _('--cert', required=True,
        help="Local peer's signed certificate in .pem format."
             " Common name defines the allocated prefix in the network.")
    _('--key', required=True,
        help="Local peer's private key in .pem format.")
    _('--client-count', default=10, type=int,
        help="Number of client tunnels to set up.")
    _('--max-clients', type=int,
        help="Maximum number of accepted clients per OpenVPN server. (default:"
             " client-count * 2, which actually represents the average number"
             " of tunnels to other peers)")
    _('--tunnel-refresh', default=300, type=int,
        help="Interval in seconds between two tunnel refreshes: the worst"
             " tunnel is closed if the number of client tunnels has reached"
             " its maximum number (client-count).")
    _('--remote-gateway', action='append', dest='gw_list',
        help="Force each tunnel to be created through one of the given"
             " gateways, in a round-robin fashion.")
    # NOTE: argparse 'append' adds to a copy of the default, so 'udp' and
    # 'udp6' stay disabled unless 'none' is given.
    _('--disable-proto', action='append',
        choices=('none', 'udp', 'tcp', 'udp6', 'tcp6'), default=['udp', 'udp6'],
        help="Never try to create tunnels using the given protocols."
             " 'none' has precedence over other options.")
    _('--client', metavar='HOST,PORT,PROTO[;...]',
        help="Do not run any OpenVPN server, but only 1 OpenVPN client,"
             " with specified remotes. Any other option not required in this"
             " mode is ignored (e.g. client-count, max-clients, etc.)")
    _('--neighbour', metavar='CN', action='append', default=[],
        help="List of peers that should be reachable directly, by creating"
             " tunnels if necessary.")

    return parser.parse_args()

def maybe_renew(path, cert, info, renew):
    """Renew the certificate stored at *path* when it is close to expiry.

    cert  -- currently loaded certificate (pyOpenSSL X509 object)
    info  -- label used in log messages (e.g. "Certificate")
    renew -- callable returning the new certificate as PEM, or a false
             value when nothing is available

    Renewal is attempted once now >= notAfter(cert) - RENEW_PERIOD.  Each
    new certificate is written to path + '.new' (owner copied from the
    existing file when possible) and then renamed over path.

    Returns (cert, next_renew): the up-to-date certificate and the
    timestamp of the next renewal attempt.  If renew() fails, returns
    nothing new or returns the current certificate, the error is logged
    and the next attempt is scheduled 24 hours later.
    """
    def give_up(with_traceback):
        # Log the failure and retry in one day with the latest cert we have.
        logging.error("%s not renewed. Will retry tomorrow.",
                      info, exc_info=with_traceback)
        return cert, time.time() + 86400

    while True:
        renew_at = utils.notAfter(cert) - RENEW_PERIOD
        if renew_at > time.time():
            # Still valid for long enough: nothing to do yet.
            return cert, renew_at
        try:
            pem = renew()
            unchanged = not pem or pem == crypto.dump_certificate(
                crypto.FILETYPE_PEM, cert)
            if not unchanged:
                cert = crypto.load_certificate(crypto.FILETYPE_PEM, pem)
        except Exception:
            return give_up(1)
        if unchanged:
            return give_up(0)

        # Write to a temporary file first, then rename it over the old one.
        tmp_path = path + '.new'
        with open(tmp_path, 'w') as out:
            out.write(pem)
        try:
            st = os.stat(path)
            os.chown(tmp_path, st.st_uid, st.st_gid)
        except OSError:
            pass  # best effort: keep whatever owner the new file got
        os.rename(tmp_path, path)
        logging.info("%s renewed until %s UTC",
                     info, time.asctime(time.gmtime(utils.notAfter(cert))))

def main():
    # Get arguments
    config = getConfig()
    with open(config.ca) as f:
        ca = crypto.load_certificate(crypto.FILETYPE_PEM, f.read())
    with open(config.cert) as f:
        cert = crypto.load_certificate(crypto.FILETYPE_PEM, f.read())
    prefix = utils.binFromSubnet(utils.subnetFromCert(cert))
    config.openvpn_args += (
        '--ca', config.ca,
        '--cert', config.cert,
        '--key', config.key)
    # TODO: verify certificates (should we moved to M2Crypto ?)

    if config.test:
        sys.exit(eval(config.test, None, config.__dict__))

    # Set logging
    utils.setupLog(config.verbose, os.path.join(config.log, 're6stnet.log'))

    logging.trace("Environment: %r", os.environ)
    logging.trace("Configuration: %r", config)
    utils.makedirs(config.state)
    db_path = os.path.join(config.state, 'peers.db')
    if config.ovpnlog:
        plib.ovpn_log = config.log

    exit.signal(0, signal.SIGINT, signal.SIGTERM)
    exit.signal(-1, signal.SIGHUP, signal.SIGUSR2)

    registry = RegistryClient(config.registry, config.key, ca)
    cert, next_renew = maybe_renew(config.cert, cert, "Certificate",
                                   lambda: registry.renewCertificate(prefix))
    ca, ca_renew = maybe_renew(config.ca, ca, "CA Certificate", registry.getCa)
    if next_renew > ca_renew:
        next_renew = ca_renew
    network = utils.networkFromCa(ca)

    if config.max_clients is None:
        config.max_clients = config.client_count * 2

    if 'none' in config.disable_proto:
        config.disable_proto = ()
    if not config.table:
        # Make sure we won't tunnel over re6st.
        config.disable_proto = tuple(set(('tcp6', 'udp6')).union(
            config.disable_proto))
    address = ()
    server_tunnels = {}
    forwarder = None
    if config.client:
        config.babel_args.append('re6stnet')
    elif config.max_clients:
        if config.pp:
            pp = [(int(port), proto) for port, proto in config.pp]
            for port, proto in pp:
                if proto in config.disable_proto:
                    sys.exit("error: conflicting options --disable-proto %s"
                             " and --pp %u %s" % (proto, port, proto))
        else:
            pp = [x for x in ((1194, 'udp'), (1194, 'tcp'))
                    if x[1] not in config.disable_proto]
        def ip_changed(ip):
            for family, proto_list in ((socket.AF_INET, ('tcp', 'udp')),
                                       (socket.AF_INET6, ('tcp6', 'udp6'))):
                try:
                    socket.inet_pton(family, ip)
                    break
                except socket.error:
                    pass
            else:
                family = None
            return family, [(ip, str(port), proto) for port, proto in pp
                            if not family or proto in proto_list]
        if config.gw_list:
          gw_list = deque(config.gw_list)
          def remote_gateway(dest):
            gw_list.rotate()
            return gw_list[0]
        else:
          remote_gateway = None
        if len(config.ip) > 1:
            if 'upnp' in config.ip or 'any' in config.ip:
                sys.exit("error: argument --ip can be given only once with"
                         " 'any' or 'upnp' value")
            logging.info("Multiple --ip passed: note that re6st does nothing to"
                " make sure that incoming paquets are replied via the correct"
                " gateway. So without manual network configuration, this can"
                " not be used to accept server connections from multiple"
                " gateways.")
        if 'upnp' in config.ip or not config.ip:
            logging.info('Attempting automatic configuration via UPnP...')
            try:
                from re6st.upnpigd import Forwarder
                forwarder = Forwarder('re6stnet openvpn server')
            except Exception, e:
                if config.ip:
                    raise
                logging.info("%s: assume we are not NATed", e)
            else:
                atexit.register(forwarder.clear)
                for port, proto in pp:
                    forwarder.addRule(port, proto)
                ip_changed = forwarder.checkExternalIp
                address = ip_changed(),
        elif 'any' not in config.ip:
            address = map(ip_changed, config.ip)
            ip_changed = None
        for x in pp:
            server_tunnels.setdefault('re6stnet-' + x[1], x)
    else:
        ip_changed = remote_gateway = None

    def call(cmd):
        logging.debug('%r', cmd)
        p = subprocess.Popen(cmd, stdout=subprocess.PIPE,
                                  stderr=subprocess.PIPE)
        stdout, stderr = p.communicate()
        if p.returncode:
            raise EnvironmentError("%r failed with error %u\n%s"
                                   % (' '.join(cmd), p.returncode, stderr))
        return stdout
    def required(arg):
        if not getattr(config, arg):
            sys.exit("error: argument --%s is required" % arg)
    def ip(object, *args):
        args = ['ip', '-6', object, 'add'] + list(args)
        call(args)
        args[3] = 'del'
        cleanup.append(lambda: subprocess.call(args))

    try:
        subnet = network + prefix
        my_ip = utils.ipFromBin(subnet, '1')
        my_subnet = '%s/%u' % (utils.ipFromBin(subnet), len(subnet))
        my_network = "%s/%u" % (utils.ipFromBin(network), len(network))
        os.environ['re6stnet_ip'] = my_ip
        os.environ['re6stnet_iface'] = config.main_interface
        os.environ['re6stnet_subnet'] = my_subnet
        os.environ['re6stnet_network'] = my_network
        my_ip += '/%s' % len(subnet)

        # Init db and tunnels
        tunnel_interfaces = server_tunnels.keys()
        timeout = 4 * config.hello
        cleanup = []
        if config.client_count and not config.client:
            required('registry')
            peer_db = db.PeerDB(db_path, registry, config.key, network, prefix)
            cleanup.append(lambda: peer_db.cacheMinimize(config.client_count))
            tunnel_manager = tunnel.TunnelManager(config.control_socket,
                peer_db, config.openvpn_args, timeout, config.tunnel_refresh,
                config.client_count, config.iface_list, network, prefix,
                address, ip_changed, config.encrypt, remote_gateway,
                config.disable_proto, config.neighbour)
            cleanup.append(tunnel_manager.sock.close)
            tunnel_interfaces += tunnel_manager.new_iface_list
            write_pipe = tunnel_manager.write_pipe
        else:
            tunnel_manager = write_pipe = None

        try:
            exit.acquire()
            # Source address selection is defined by RFC 6724, and in most
            # applications, it usually works  thanks to rule 5 (prefer outgoing
            # interface). But here, it rarely applies because we use several
            # interfaces to connect to a re6st network.
            # Rule 7 is little strange because it prefers temporary addresses
            # over IP with a longer matching prefix (rule 8, which is not even
            # mandatory).
            # So only rule 6 can make the difference, i.e. prefer same label.
            # The value of the label does not matter, except that it must be
            # different from ::/0's (normally equal to 1).
            # XXX: This does not work with extra interfaces that already have
            #      an public IP so Babel must be changed to set a source
            #      address on routes it installs.
            ip('addrlabel', 'prefix', my_network, 'label', '99')
            # prepare persistent interfaces
            if config.client:
                address_list = [x for x in utils.parse_address(config.client)
                                  if x[2] not in config.disable_proto]
                if not address_list:
                    sys.exit("error: --disable_proto option disables"
                             " all addresses given by --client")
                cleanup.append(plib.client('re6stnet',
                    address_list, config.encrypt, '--ping-restart',
                    str(timeout), *config.openvpn_args).stop)
            elif server_tunnels:
                required('dh')
                for iface, (port, proto) in server_tunnels.iteritems():
                    cleanup.append(plib.server(iface, config.max_clients,
                        config.dh, write_pipe, port, proto, config.encrypt,
                        '--ping-exit', str(timeout), *config.openvpn_args).stop)

            ip('addr', my_ip, 'dev', config.main_interface)
            if_rt = ['ip', '-6', 'route', 'del',
                     'fe80::/64', 'dev', config.main_interface]
            if config.main_interface == 'lo':
                # WKRD: Removed this useless route now, since the kernel does
                #       not even remove it on exit.
                subprocess.call(if_rt)
            if_rt[4] = my_subnet
            cleanup.append(lambda: subprocess.call(if_rt))
            x = [my_network]
            if config.gateway:
                config.table = 0
            elif config.table:
                x += 'table', str(config.table)
                try:
                    ip('rule', 'from', *x)
                except EnvironmentError:
                    logging.error("It seems that your kernel was compiled"
                        " without support for source address based routing"
                        " (CONFIG_IPV6_SUBTREES). Consider using --table=0"
                        " option if you can't change your kernel.")
                    raise
                ip('rule', 'to', *x)
                call(if_rt)
                if_rt += x[1:]
                call(if_rt[:3] + ['add', 'proto', 'static'] + if_rt[4:])
            else:
                def check_no_default_route():
                    for route in call(('ip', '-6', 'route', 'show',
                                        'default')).splitlines():
                        if ' proto 42 ' not in route:
                            sys.exit("Detected default route (%s)"
                                " whereas you specified --table=0."
                                " Fix your configuration." % route)
                check_no_default_route()
                def check_no_default_route_thread():
                    try:
                        while True:
                            time.sleep(60)
                            try:
                                check_no_default_route()
                            except OSError, e:
                                if e.errno != errno.ENOMEM:
                                    raise
                    except:
                        utils.log_exception()
                    finally:
                        exit.kill_main(1)
                t = threading.Thread(target=check_no_default_route_thread)
                t.daemon = True
                t.start()
            ip('route', 'unreachable', *x)

            config.babel_args += config.iface_list
            cleanup.append(plib.router(subnet, config.hello, config.table,
                os.path.join(config.log, 'babeld.log'),
                os.path.join(config.state, 'babeld.state'),
                config.babel_pidfile, tunnel_interfaces,
                config.control_socket,
                *config.babel_args).stop)
            if config.up:
                exit.release()
                r = os.system(config.up)
                if r:
                    sys.exit(r)
                exit.acquire()
            for cmd in config.daemon or ():
                cleanup.insert(-1, utils.Popen(cmd, shell=True).stop)

            # main loop
            select_list = [forwarder.select] if forwarder else []
            if tunnel_manager:
                select_list.append(tunnel_manager.select)
                cleanup[-1:-1] = (tunnel_manager.delInterfaces,
                                  tunnel_manager.killAll)
            exit.release()
            def renew():
                raise ReexecException("Restart to renew certificate")
            select_list.append(utils.select)
            while True:
                args = {}, {}, [(next_renew, renew)]
                for s in select_list:
                    s(*args)
        finally:
            # XXX: We have a possible race condition if a signal is handled at
            #      the beginning of this clause, just before the following line.
            exit.acquire(0) # inhibit signals
            while cleanup:
                try:
                    cleanup.pop()()
                except:
                    pass
            exit.release()
    except sqlite3.Error:
        logging.exception("Restarting with empty cache")
        os.rename(db_path, db_path + '.bak')
    except ReexecException, e:
        logging.info(e)
    except Exception:
        utils.log_exception()
        sys.exit(1)
    try:
        sys.exitfunc()
    finally:
        os.execvp(sys.argv[0], sys.argv)

if __name__ == '__main__':
    # Start the daemon only when executed as a script, not when imported.
    main()
