#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import division, print_function
import argparse, curses, errno, os, random, select
import signal, socket, subprocess, sys, threading, time
from contextlib import contextmanager
from datetime import datetime
from functools import partial
from multiprocessing import Lock, RawArray
from struct import Struct
from netfilterqueue import NetfilterQueue
import gevent.socket # preload for subprocesses
from neo.client.Storage import Storage
from neo.lib import logging, util
from neo.lib.connector import SocketConnector
from neo.lib.debug import PdbSocket
from neo.lib.node import Node
from neo.lib.protocol import NodeTypes
from neo.lib.util import timeStringFromTID, p64, u64
from neo.storage.app import DATABASE_MANAGER_DICT, \
    Application as StorageApplication
from neo.tests import getTempDirectory
from neo.tests.ConflictFree import ConflictFreeLog
from neo.tests.functional import AlreadyStopped, NEOCluster, Process
from neo.tests.stress import StressApplication
from transaction import begin as transaction_begin
from ZODB import DB, POSException

# Per address family: (nftables family name, socket option level, socket
# option) -- the option is the one setting the IPv4 TOS / IPv6 traffic class
# byte, used by setDSCP(); the family name is used to build NFT_TEMPLATE.
INET = {
    socket.AF_INET:  ('ip',  socket.IPPROTO_IP, socket.IP_TOS),
    socket.AF_INET6: ('ip6', socket.IPPROTO_IPV6, socket.IPV6_TCLASS),
}

# nftables ruleset used to inject TCP resets (filled by
# Application.startCluster). Placeholders, in order: nft family, table name,
# nft family, queue rules (one per storage node), nft family.
# Incoming TCP packets with DSCP 1, except initial SYNs, go through the
# 'nfqueue' chain, which hands them to the NFQueue process of the target
# storage node; packets that this process marks with 1 are then rejected
# with a TCP reset by the 'filter' chain.
NFT_TEMPLATE = """\
    table %s %s {
        chain mangle {
            type filter hook input priority -150
            policy accept
            %s dscp 1 tcp flags & (fin|syn|rst|ack) != syn jump nfqueue
        }
        chain nfqueue {
            %s
        }
        chain filter {
            type filter hook input priority 0
            policy accept
            meta l4proto tcp %s dscp 1 mark 1 counter reject with tcp reset
        }
    }"""

# Aggressive TCP keepalive settings for all NEO connections; presumably
# (idle, interval, probe count) in seconds so that dead peers are detected
# quickly -- TODO confirm against SocketConnector.
SocketConnector.KEEPALIVE = 5, 1, 1

def child_coverage(self):
    """Close and forget the coverage pipe fd of this process, if any.

    XXX: Collecting coverage results just before killing subprocesses does
         not work for processes that may run code that is not interruptible
         by Python code (e.g. Lock.acquire). Nodes with a single epoll loop
         are usually fine, but coverage support is broken for clients like
         the ones here, so this only does the cleanup needed for the
         assertion in __del__.
    """
    fd = self._coverage_fd
    if fd is None:
        return
    os.close(fd)
    del self._coverage_fd
# Monkey-patch every functional-test Process (including Client and NFQueue
# below) so that it uses child_coverage() instead of the default one.
Process.child_coverage = child_coverage

def setDSCP(connection, dscp):
    """Set the DSCP value of the socket behind `connection`.

    The socket option comes from INET (IP_TOS or IPV6_TCLASS depending on
    the address family). DSCP occupies the 6 high bits of that byte, hence
    the shift by 2.
    """
    connector = connection.getConnector()
    level, option = INET[connector.af_type][1:]
    connector.socket.setsockopt(level, option, dscp << 2)

def dscpPatch(dscp):
    """Wrap Node.setConnection so that storage connections get DSCP 1.

    `dscp` becomes the class-wide Node.dscp. Whenever it is non-zero at the
    time a connection is attached to a storage node, that connection is
    tagged with DSCP 1, which is what the NFT_TEMPLATE rules match on.
    """
    wrapped = Node.setConnection
    def setConnection(self, connection, force=None):
        if self.dscp and self.getType() == NodeTypes.STORAGE:
            setDSCP(connection, 1)
        return wrapped(self, connection, force)
    Node.dscp = dscp
    Node.setConnection = setConnection

class Client(Process):
    """Stress-test client process writing to the ConflictFreeLog objects
    found in the DB root, from one or several worker threads.

    Every committed transaction appends the same record to 2 randomly chosen
    logs. A record is `_fmt`-packed from the worker's transaction counter and
    its name. `check` later verifies that, for each worker, the counters form
    the sequence 0, 0, 1, 1, 2, 2, ... (each value exactly twice, no gap).
    """

    # Record layout: big-endian unsigned int (transaction counter) followed
    # by the worker name, NUL-padded to 200 bytes.
    _fmt = '!I200s'
    # Sum of `count` at the previous screen refresh (Application.refresh_ids).
    prev_count = 0

    def __init__(self, command, thread_count, **kw):
        super(Client, self).__init__(command)
        # Keyword arguments for neo.client.Storage (see run).
        self.config = kw
        # One counter of committed transactions per worker thread, in shared
        # memory so that the parent process can display progress.
        self.count = RawArray('I', thread_count)
        self.thread_count = thread_count

    def run(self):
        """Client process entry point: run the worker(s) until one of them
        exits, then close the DB."""
        from neo.lib.threaded_app import registerLiveDebugger
        registerLiveDebugger() # for on_log
        # Start with DSCP tagging disabled: the setDSCP method below enables
        # it only while a commit is in progress.
        dscpPatch(0)
        self._dscp_lock = threading.Lock()
        storage = Storage(**self.config)
        db = DB(storage=storage)
        try:
            if self.thread_count == 1:
                self.worker(db)
            else:
                # Each worker writes a byte to `w` when it exits, so the first
                # worker to stop makes run() return.
                r, w = os.pipe()
                try:
                    for i in xrange(self.thread_count):
                        t = threading.Thread(target=self.worker,
                            args=(db, i, w), name='worker-%s' % i)
                        t.daemon = 1
                        t.start()
                    while 1:
                        try:
                            os.read(r, 1)
                            break
                        except OSError, e:
                            # Retry reads interrupted by a signal.
                            if e.errno != errno.EINTR:
                              raise
                finally:
                    os.close(r)
        finally:
            db.close()

    def worker(self, db, i=0, stop=None):
        """Commit transactions forever.

        `i` is the worker index (its slot in `count`). `stop`, if given, is a
        file descriptor to which one byte is written when this worker exits.
        """
        try:
            nm = db.storage.app.nm
            conn = db.open()
            r = conn.root()
            count = self.count
            name = self.command
            if self.thread_count > 1:
                # Distinct names so that check() tracks each thread's
                # sequence separately.
                name += ':%s' % i
            # j: number of committed transactions (next counter to write).
            # k: last j for which an error was logged (log once per j).
            j = 0
            k = None
            # Assumes the root only contains the logs created by
            # Application.startCluster.
            logs = r.values()
            pack = Struct(self._fmt).pack
            while 1:
                txn = transaction_begin()
                try:
                    data = pack(j, name)
                    for log in random.sample(logs, 2):
                        log.append(data)
                    txn.note(name)
                    # Tag storage connections with DSCP 1 only during the
                    # commit, so that injected faults hit commits.
                    self.setDSCP(nm, 1)
                    try:
                        txn.commit()
                    finally:
                        self.setDSCP(nm, -1)
                except (
                    POSException.StorageError,  # XXX: 'already connected' error
                    POSException.ConflictError, # XXX: same but during conflict resolution
                    ), e:
                    # Treated as fatal rather than retried.
                    if 'unexpected packet:' in str(e):
                        raise
                    if j != k:
                        logging.exception('j = %s', j)
                        k = j
                    # Retry with the same counter value.
                    txn.abort()
                    continue
                j += 1
                count[i] = j
        finally:
            if stop is not None:
                try:
                    os.write(stop, '\0')
                except OSError:
                    # e.g. the read end was already closed by run().
                    pass

    def setDSCP(self, nm, dscp):
        """Add `dscp` (+1 or -1) to the number of commits in progress and
        keep storage connections tagged with DSCP 1 while it is non-zero.

        Node.dscp is used as that counter (see dscpPatch). Sockets are only
        updated when it goes from 0 to non-zero, or back to 0.
        """
        with self._dscp_lock:
            prev = Node.dscp
            dscp += prev
            Node.dscp = dscp
            if dscp and prev:
                return
            for node in nm.getStorageList():
                try:
                    setDSCP(node.getConnection(), dscp)
                # NOTE(review): AttributeError/AssertionError presumably come
                # from storage nodes that are not connected -- confirm.
                except (AttributeError, AssertionError,
                        # XXX: EBADF due to race condition
                        socket.error):
                    pass

    @classmethod
    def check(cls, r):
        """Verify the records of all logs in root `r` and print statistics.

        On the first inconsistency, exits via sys.exit() with a message
        naming the worker, the expected counter and the stored one.
        """
        # nodes: worker name -> index into `hosts` (later reused as counts).
        nodes = {}
        hosts = []
        # [total estimated size, number] of full buckets, for statistics.
        buckets = [0, 0]
        item_list = []
        unpack = Struct(cls._fmt).unpack
        def decode(item):
            i, host = unpack(item)
            return i, host.rstrip('\0')
        for log in r.values():
            # Presumably log._next starts a chain of full buckets that ends
            # back at `log` itself (whose size is not counted below) -- see
            # ConflictFreeLog.
            bucket = log._next
            if bucket is None:
                bucket = log
                # NOTE(review): this overwrites (rather than adds to) the
                # totals of previously visited logs; it looks like a fallback
                # so that the average below never divides by zero -- confirm.
                buckets[:] = bucket._p_estimated_size, 1
            while 1:
                for item in bucket._log:
                    i, host = decode(item)
                    try:
                        node = nodes[host]
                    except KeyError:
                        node = nodes[host] = len(nodes)
                        hosts.append(host)
                    item_list.append((i, node))
                if bucket is log:
                    break
                buckets[0] += bucket._p_estimated_size
                buckets[1] += 1
                bucket = bucket._next
        item_list.sort()
        # Number of records seen so far per worker: each transaction wrote
        # its counter twice, so the n-th record must hold n // 2.
        nodes = [0] * len(nodes)
        for i, node in item_list:
            j = nodes[node] // 2
            if i != j:
                #import code; code.interact(banner="", local=locals())
                sys.exit('node: %s, expected: %s, stored: %s'
                         % (hosts[node], j, i))
            nodes[node] += 1
        for node, host in sorted(enumerate(hosts), key=lambda x: x[1]):
            print('%s\t%s' % (nodes[node], host))
        # NOTE(review): with an empty root, `log` is unbound and buckets[1]
        # is 0 below.
        print('average bucket size: %f' % (buckets[0] / buckets[1]))
        print('target bucket size:', log._bucket_size)
        print('number of full buckets:', buckets[1])

    @property
    def logfile(self):
        # Path of this client's log (rotated by Application._logrotate_thread).
        return self.config['logfile']


class NFQueue(Process):
    """Process serving one netfilter queue, numbered after a storage node id.

    It accepts every packet that the nftables 'nfqueue' chain sends to it.
    When `delay` is non-zero, acceptance is postponed by a random duration of
    up to `delay` seconds. Once `lock` is released (see
    Application.tcpReset), the next packet gets mark 1, which the 'filter'
    chain of NFT_TEMPLATE rejects with a TCP reset.
    """

    # Maximum random delay (seconds) before accepting a packet; 0 disables
    # delaying. main() overrides it from --delay. This default keeps run()
    # from failing with AttributeError when no override was made.
    delay = 0

    def __init__(self, queue):
        super(NFQueue, self).__init__('nfqueue_%i' % queue)
        # Created in the acquired state: each release allows exactly one
        # packet to be marked, because the callback re-acquires the lock.
        self.lock = l = Lock(); l.acquire()
        self.queue = queue

    def run(self):
        """Bind to queue `self.queue` and process packets until interrupted."""
        acquire = self.lock.acquire
        delay = self.delay
        nfqueue = NetfilterQueue()
        if delay:
            # These imports shadow the `socket` and `random` modules, within
            # this method only.
            from gevent import sleep, socket, spawn
            from random import random
            def callback(packet):
                # acquire(0) is non-blocking: mark only if released.
                if acquire(0): packet.set_mark(1)
                else: sleep(random() * delay)
                packet.accept()
            # Handle each packet in its own greenlet so that delays overlap.
            callback = partial(spawn, callback)
        else:
            def callback(packet):
                if acquire(0): packet.set_mark(1)
                packet.accept()
        nfqueue.bind(self.queue, callback)
        try:
            if delay:
                # Drive the queue through a gevent socket, presumably so that
                # spawned greenlets run while waiting for packets.
                s = socket.fromfd(nfqueue.get_fd(),
                    socket.AF_UNIX, socket.SOCK_STREAM)
                try:
                    nfqueue.run_socket(s)
                finally:
                    s.close()
            else:
                while 1:
                    nfqueue.run() # returns on signal (e.g. SIGWINCH)
        finally:
            nfqueue.unbind()


class Alarm(threading.Thread):
    """Context manager interrupting its `with` body after `timeout` seconds.

    On entry, a handler raising a private exception is installed for
    `signal`, and this thread starts waiting on a pipe. If the body does not
    finish within `timeout` seconds, the thread sends `signal` to the current
    process. The handler then raises the private exception, which __exit__
    swallows. Leaving the body in time closes the pipe, which wakes the
    thread up without sending anything.

    NOTE: signal.signal() only works in the main thread, so this must be used
    from there.
    """

    # Private exception object, identified by identity in __exit__.
    __interrupt = BaseException()

    def __init__(self, signal, timeout):
        super(Alarm, self).__init__()
        self.__signal = signal
        self.__timeout = timeout

    def __enter__(self):
        self.__r, self.__w = os.pipe()
        self.__prev = signal.signal(self.__signal, self.__raise)
        self.start()

    def __exit__(self, t, v, tb):
        try:
            try:
                # Closing the write end makes the thread's select() return
                # before the timeout, so no signal is sent.
                os.close(self.__w)
                self.join()
            finally:
                os.close(self.__r)
                signal.signal(self.__signal, self.__prev)
            # Suppress the exception only if it is our interruption.
            return v is self.__interrupt
        except BaseException as e:
            # The interruption may also fire while exiting (e.g. during
            # join()): swallow it too, but let anything else propagate.
            if e is not self.__interrupt:
                raise

    def __raise(self, sig, frame):
        raise self.__interrupt

    def run(self):
        # Nothing readable (not even EOF) before the timeout: interrupt.
        if not select.select((self.__r,), (), (), self.__timeout)[0]:
            os.kill(os.getpid(), self.__signal)


class NEOCluster(NEOCluster):
    """Functional-test NEOCluster whose node processes always get an
    explicitly allocated port.

    Presumably this gives restarted storage nodes a stable address, which
    the per-port nftables rules of Application.startCluster rely on --
    TODO confirm.
    """

    def _newProcess(self, node_type, logfile=None, port=None, **kw):
        # When no port is given (None, or 0), allocate one up front. The
        # base method's return value is propagated instead of being dropped.
        # Note: at call time, the global name NEOCluster refers to this
        # subclass, so super() reaches the imported base class.
        return super(NEOCluster, self)._newProcess(node_type, logfile,
            port or self.port_allocator.allocate(
                self.address_type, self.local_ip),
            **kw)


class Application(StressApplication):
    """Stress test driver.

    Runs a NEO cluster and Client processes filling it, plus an
    nftables/NFQueue setup that is used to inject faults: TCP resets
    (tcpReset) and storage restarts (restartStorages). At exit, the ingested
    data is verified with Client.check.
    """

    # NOTE(review): not used in this class; presumably read by
    # StressApplication -- confirm.
    _blocking = None

    def __init__(self, client_count, thread_count, restart_ratio, logrotate,
                 *args, **kw):
        # logrotate: log rotation period in seconds (0 disables rotation).
        # restart_ratio is not used here; presumably StressApplication uses
        # it to choose between tcpReset and restartStorages -- confirm.
        self.client_count = client_count
        self.thread_count = thread_count
        self.logrotate = logrotate
        self.restart_ratio = restart_ratio
        self.cluster = cluster = NEOCluster(*args, **kw)
        # Make the firewall also affect connections between storage nodes.
        StorageApplication__init__ = StorageApplication.__init__
        def __init__(self, config):
            dscpPatch(1)
            StorageApplication__init__(self, config)
        StorageApplication.__init__  = __init__

        super(Application, self).__init__(cluster.SSL,
            util.parseMasterList(cluster.master_nodes))
        self._nft_family = INET[cluster.address_type][0]
        # One nftables table per test run, so that it can be deleted as a
        # whole (see _cleanFirewall).
        self._nft_table = 'stress_%s' % os.getpid()
        self._blocked = []
        n = kw['replicas']
        # NOTE(review): presumably the maximum number of storage nodes that
        # may be faulty at the same time, used by StressApplication -- confirm.
        self._fault_count = len(kw['db_list']) * n // (1 + n)

    @property
    def name(self):
        return self.cluster.cluster_name

    def run(self):
        """Run the stress test, then check the ingested data and stop."""
        super(Application, self).run()
        try:
            with self.db() as r:
                Client.check(r)
        finally:
            self.cluster.stop()

    @contextmanager
    def db(self):
        """Start the cluster and yield the root of a new ZODB connection."""
        cluster = self.cluster
        cluster.start()
        db, conn = cluster.getZODBConnection()
        try:
            yield conn.root()
        finally:
            db.close()

    def startCluster(self):
        """Create the logs, the firewall with its NFQueue processes, and
        start the Client processes (and the log rotation thread)."""
        # 2 logs per worker thread: Client.worker writes to 2 random logs.
        with self.db() as r:
            txn = transaction_begin()
            for i in xrange(2 * self.client_count * self.thread_count):
                r[i] = ConflictFreeLog()
            txn.commit()
        cluster = self.cluster
        process_list = cluster.process_dict[NFQueue] = []
        nft_family = self._nft_family
        queue = []
        # Sorted by node id, so that tcpReset can index NFQueue processes by
        # nid - 1 (it asserts this).
        for _, (ip, port), nid, _, _ in sorted(cluster.getStorageList(),
                                               key=lambda x: x[2]):
            queue.append(
                "%s daddr %s tcp dport %s counter queue num %s bypass"
                % (nft_family, ip, port, nid))
            p = NFQueue(nid)
            process_list.append(p)
            p.start()
        ruleset = NFT_TEMPLATE % (nft_family, self._nft_table,
            nft_family, '\n            '.join(queue), nft_family)
        p = subprocess.Popen(('nft', '-f', '-'), stdin=subprocess.PIPE,
            stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
        err = p.communicate(ruleset)[0].rstrip()
        if p.poll():
            sys.exit("Failed to apply the following ruleset:\n%s\n%s"
                % (ruleset, err))
        process_list = cluster.process_dict[Client] = []
        config = cluster.getClientConfig()
        self.started = time.time()
        for i in xrange(self.client_count):
            name = 'client_%i' % i
            p = Client(name, self.thread_count,
                logfile=os.path.join(cluster.temp_dir, name + '.log'),
                **config)
            process_list.append(p)
            p.start()
        if self.logrotate:
            t = threading.Thread(target=self._logrotate_thread)
            t.daemon = 1
            t.start()

    def stopCluster(self, wait=None):
        """Remove the firewall and stop all processes.

        `wait`, if given, looks like a time.time() value (TODO confirm with
        StressApplication): clients and storage nodes get until 5 seconds
        after it to exit by themselves, before the cluster is stopped.
        """
        # NOTE(review): not released here; presumably released by
        # StressApplication once restarted -- confirm.
        self.restart_lock.acquire()
        self._cleanFirewall()
        process_dict = self.cluster.process_dict
        if wait:
            # Give time to flush logs before SIGKILL.
            wait += 5 - time.time()
            if wait > 0:
                # Abort the waiting (through SIGUSR1) if it takes too long.
                with Alarm(signal.SIGUSR1, wait):
                    # Note: the inner loop rebinds `x`; this works because
                    # process_dict[x] is evaluated before.
                    for x in Client, NodeTypes.STORAGE:
                        for x in process_dict[x]:
                            x.wait()
        self.cluster.stop()
        try:
            del process_dict[NFQueue], process_dict[Client]
        except KeyError:
            pass

    def _logrotate_thread(self):
        """Rotate logs every `self.logrotate` seconds, forever.

        The log of each process (except NFQueue ones) is renamed with a UTC
        timestamp suffix, and the process is sent SIGRTMIN+1 (presumably so
        that it reopens its log file -- TODO confirm). A rotated file is
        compressed (zstd if importable, gzip otherwise), keeping its
        timestamps, at the next rotation of the same log.
        """
        try:
            import zstd
        except ImportError:
            import gzip, shutil
            zstd = None
        compress = []
        # Original log path -> path it was last rotated to.
        rotated = {}
        t = time.time()
        while 1:
            # Fixed schedule, so that the processing time does not drift.
            t += self.logrotate
            x = t - time.time()
            if x > 0:
                time.sleep(x)
            x = datetime.utcnow().strftime('-%Y%m%d%H%M%S.log')
            for p, process_list in self.cluster.process_dict.iteritems():
                if p is not NFQueue:
                    for p in process_list:
                        log = p.logfile
                        if os.path.exists(log):
                            y = rotated.get(log)
                            if y:
                                compress.append(y)
                            y = log[:-4] + x
                            os.rename(log, y)
                            rotated[log] = y
                            try:
                                p.kill(signal.SIGRTMIN+1)
                            except AlreadyStopped:
                                pass
            for log in compress:
                if zstd:
                    with open(log, 'rb') as src:
                        x = zstd.compress(src.read())
                    y = log + '.zst'
                    with open(y, 'wb') as dst:
                        dst.write(x)
                else:
                    y = log + '.gz'
                    with open(log, 'rb') as src, gzip.open(y, 'wb') as dst:
                        shutil.copyfileobj(src, dst, 1<<20)
                x = os.stat(log)
                os.utime(y, (x.st_atime, x.st_mtime))
                os.remove(log)
            del compress[:]

    def tcpReset(self, nid):
        """Make the next queued packet to storage node `nid` be RST."""
        p = self.cluster.process_dict[NFQueue][nid-1]
        assert p.queue == nid, (p.queue, nid)
        try:
            p.lock.release()
        except ValueError:
            # Already released: a reset is already pending.
            pass

    def restartStorages(self, nids):
        """SIGKILL the storage nodes whose uuid is in `nids`, then restart
        them."""
        processes = [p for p in self.cluster.getStorageProcessList()
                       if p.uuid in nids]
        for p in processes: p.kill(signal.SIGKILL)
        time.sleep(1)
        for p in processes: p.wait()
        for p in processes: p.start()

    def _cleanFirewall(self):
        """Delete the nftables table of this run; errors are ignored (e.g.
        if it does not exist)."""
        with open(os.devnull, "wb") as f:
            subprocess.call(('nft', 'delete', 'table',
                self._nft_family, self._nft_table), stderr=f)

    # Number of screen lines written by refresh_ids.
    _ids_height = 4

    def refresh_ids(self, y):
        """Draw the ids/progress section at screen row `y`.

        Shows the last oid/tid, the commit count of each client (bold when
        unchanged since the previous refresh), the total increase, the
        elapsed time and the average commit rate. self.stdscr, self.loid and
        self.ltid are presumably maintained by StressApplication.
        """
        attr = curses.A_NORMAL, curses.A_BOLD
        stdscr = self.stdscr
        ltid = self.ltid
        stdscr.addstr(y, 0,
            'last oid: 0x%x\nlast tid: 0x%x (%s)\nclients: '
            % (u64(self.loid), u64(ltid), timeStringFromTID(ltid)))
        before = after = 0
        for i, p in enumerate(self.cluster.process_dict[Client]):
            if i:
                stdscr.addstr(', ')
            count = sum(p.count)
            before += p.prev_count
            after += count
            stdscr.addstr(str(count), attr[p.prev_count==count])
            p.prev_count = count
        elapsed = time.time() - self.started
        s, ms = divmod(int(elapsed * 1000), 1000)
        m, s = divmod(s, 60)
        stdscr.addstr(' (+%s)\n\t%sm%02u.%03us (%f/s)\n' % (
            after - before, m, s, ms, after / elapsed))


def console(port, app):
    """Serve interactive Pdb sessions over TCP, forever.

    Listens on the cluster's local IP and `port` (0 lets the system choose).
    Each accepted connection gets its own daemon thread running Pdb, from
    which `app` (the Application instance) is reachable.
    """
    from pdb import Pdb
    cluster = app.cluster
    def session(sock):
        Pdb(stdin=sock, stdout=sock).set_trace()
        app # set_trace() stops here, where `app` is in scope
    listener = socket.socket(cluster.address_type, socket.SOCK_STREAM)
    # XXX: The following commented line would only work with Python 3, which
    #      fixes refcounting of sockets (e.g. when there's a call to .accept()).
    #Process.on_fork.append(listener.close)
    listener.bind((cluster.local_ip, port))
    listener.listen(0)
    while True:
        peer = PdbSocket(listener.accept()[0])
        handler = threading.Thread(target=session, args=(peer,))
        handler.daemon = 1
        handler.start()


class ArgumentDefaultsHelpFormatter(argparse.HelpFormatter):
    """Help formatter that documents the default of undocumented options.

    Unlike argparse.ArgumentDefaultsHelpFormatter, existing help strings are
    left untouched: '(default: ...)' is only used as the help of actions that
    have no help text and whose default is neither None nor SUPPRESS.
    """

    def _format_action(self, action):
        undocumented = not action.help
        if undocumented and action.default not in (None, argparse.SUPPRESS):
            action.help = '(default: %(default)s)'
        return super(ArgumentDefaultsHelpFormatter, self)._format_action(action)


def main():
    """Parse the command line and run the requested sub-command.

    run    -- start a new stress test (Application), with a Pdb console
              served from a background thread (see console);
    check  -- verify the data of an existing DB, optionally as of `tid`;
    bisect -- binary-search the first TID whose data fails Client.check.
    """
    # Storage backends selectable with --adapter ('Importer' excluded).
    adapters = sorted(DATABASE_MANAGER_DICT)
    adapters.remove('Importer')
    default_adapter = 'SQLite'
    assert default_adapter in adapters

    kw = dict(formatter_class=ArgumentDefaultsHelpFormatter)
    parser = argparse.ArgumentParser(**kw)
    _ = parser.add_argument
    _('-6', '--ipv6', dest='address_type', action='store_const',
        default=socket.AF_INET, const=socket.AF_INET6, help='(default: IPv4)')
    _('-a', '--adapter', choices=adapters, default=default_adapter)
    _('-d', '--datadir', help="(default: same as unit tests)")
    _('-l', '--logdir', help="(default: same as --datadir)")
    _('-m', '--masters', type=int, default=1)
    _('-s', '--storages', type=int, default=8)
    _('-p', '--partitions', type=int, default=24)
    _('-r', '--replicas', type=int, default=1)
    parsers = parser.add_subparsers(dest='command')

    # argparse type for options that must be a probability.
    def ratio(value):
        value = float(value)
        if 0 <= value <= 1:
            return value
        raise argparse.ArgumentTypeError("ratio ∉ [0,1]")

    _ = parsers.add_parser('run',
        help='Start a new DB and fills it in a way that triggers many conflict'
             ' resolutions and deadlock avoidances. Stressing the cluster will'
             ' cause external faults every second, to check that NEO can'
             ' recover. The ingested data is checked at exit.',
        **kw).add_argument
    _('-c', '--clients', type=int, default=10,
        help='number of client processes')
    _('-t', '--threads', type=int, default=1,
        help='number of thread workers per client process')
    _('-r', '--restart-ratio', type=ratio, default=.5, metavar='RATIO',
        help='probability to kill/restart a storage node, rather than just'
             ' RSTing a TCP connection with this node')
    _('-C', '--console', type=int, default=0,
        help='console port (localhost) (default: any)')
    _('-D', '--delay', type=float, default=.01,
        help='randomly delay packets to storage nodes'
             '  by a duration between 0 and DELAY seconds')
    # Converted to seconds below; 0 disables log rotation.
    _('-L', '--logrotate', type=float, default=1, metavar='HOUR')

    _ = parsers.add_parser('check',
        help='Check ingested data.',
        **kw).add_argument
    # Optional TID (any int() literal, e.g. 0x...) to check the DB as of it.
    _('tid', nargs='?')

    _ = parsers.add_parser('bisect',
        help='Search for the first TID that contains corrupted data.',
        **kw).add_argument

    args = parser.parse_args()

    db_list = ['stress_neo%s' % x for x in xrange(args.storages)]
    if args.datadir:
        if args.adapter != 'SQLite':
            parser.error('--datadir is only for SQLite adapter')
        db_list = [os.path.join(args.datadir, x + '.sqlite') for x in db_list]

    kw = dict(db_list=db_list, name='stress',
        partitions=args.partitions, replicas=args.replicas,
        adapter=args.adapter, address_type=args.address_type,
        temp_dir=args.logdir or args.datadir or getTempDirectory())

    if args.command == 'run':
        # Read by NFQueue.run in the forked queue processes.
        NFQueue.delay = args.delay
        app = Application(args.clients, args.threads, args.restart_ratio,
            int(round(args.logrotate * 3600, 0)), **kw)
        t = threading.Thread(target=console, args=(args.console, app))
        t.daemon = 1
        t.start()
        app.run()
        return

    # check/bisect: reuse the existing databases.
    cluster = NEOCluster(clear_databases=False, **kw)
    try:
        cluster.start()
        storage = cluster.getZODBStorage()
        db = DB(storage=storage)
        try:
            if args.command == 'check':
                tid = args.tid
                conn = db.open(at=tid and p64(int(tid, 0)))
                Client.check(conn.root())
            else:
                assert args.command == 'bisect'
                # `ok` starts at the serial of the root object and `bad` at
                # the last transaction; they are assumed to be good and bad
                # respectively, without being checked.
                conn = db.open()
                try:
                    r = conn.root()
                    r._p_activate()
                    ok = r._p_serial
                finally:
                    conn.close()
                bad = storage.lastTransaction()
                while 1:
                    print('ok: 0x%x, bad: 0x%x' % (u64(ok), u64(bad)))
                    tid = p64((u64(ok)+u64(bad)) // 2)
                    if ok == tid:
                        break
                    conn = db.open(at=tid)
                    try:
                        # Client.check reports corruption with sys.exit().
                        Client.check(conn.root())
                    except SystemExit, e:
                        print(e)
                        bad = tid
                    else:
                        ok = tid
                    finally:
                        conn.close()
                print('bad: 0x%x (%s)' % (u64(bad), timeStringFromTID(bad)))
        finally:
            db.close()
    finally:
        cluster.stop()


if __name__ == '__main__':
    # main() returns None, which sys.exit() turns into exit status 0; failed
    # checks exit earlier with a message (see Client.check).
    sys.exit(main())
