#!/usr/bin/env python3
# Copyright 2021 Tolteck

import argparse
import asyncio
from datetime import datetime
import fileinput
import os
import random
import re
import shutil
import socket
import string
import sys
from tempfile import TemporaryDirectory

import aiocouch
import yaml


# Version of couchcopy: stored in the `metadata.yaml` of every archive written
# by `backup` and printed by the `--version` command-line flag.
__version__ = '0.1.3'
# Number of parallel consumers (each with its own CouchDB connection) started
# by `do_in_parallel`; also used as the maximum size of its work queue.
# Tweak this parameter to your needs: From 17 minutes with 16 workers to 28
# minutes with 8 workers for 10^5 databases on my machine.
N_WORKERS = 16


async def subprocess(*args, shell=False):
    """Run an external command and return its decoded standard output.

    With `shell=True` the arguments are handed to
    `asyncio.create_subprocess_shell`, otherwise to
    `asyncio.create_subprocess_exec`. Both stdout and stderr are captured.

    Raises:
        Exception: if the command exits with a non-zero return code; the
            message contains the return code, stdout and stderr.
    """
    # For debugging purposes:
    # print('> ' + str(args))
    spawn = (asyncio.create_subprocess_shell if shell
             else asyncio.create_subprocess_exec)
    pipe = asyncio.subprocess.PIPE
    proc = await spawn(*args, stdout=pipe, stderr=pipe)
    raw_out, raw_err = await proc.communicate()
    out = raw_out.decode()
    err = raw_err.decode()
    if proc.returncode != 0:
        raise Exception(
            f'Command {args[0]} returned code {proc.returncode}, stdout: '
            f'{out}, stderr: {err}')
    return out


async def backup(hostname, path, output, reuse_dir=None, tmp_dir=None,
                 nodes_names=None):
    """Copy CouchDB data files with rsync and pack them into a tar archive.

    Args:
        hostname: machine holding the CouchDB data directory; an empty value
            or 'localhost' means the files are read locally, anything else is
            used as the rsync remote host.
        path: CouchDB data directory on that machine.
        output: path of the archive to write (compressed with pigz).
        reuse_dir: optional persistent working directory, so that rsync can
            reuse the files of a previous run. It is created if missing and
            never deleted.
        tmp_dir: optional parent directory for the temporary working
            directory (only used when `reuse_dir` is not given).
        nodes_names: CouchDB node names recorded in the archive metadata.

    Raises:
        Exception: if rsync or tar exits with a non-zero return code.
    """
    # Only create (and later remove) a scratch directory when the caller did
    # not provide a persistent one.
    scratch = None if reuse_dir else TemporaryDirectory(
        prefix='couchcopy-', dir=tmp_dir)
    try:
        tmp_path = reuse_dir or scratch.name
        # makedirs: the reuse directory itself may not exist on first use.
        os.makedirs(tmp_path + '/data', exist_ok=True)

        with open(tmp_path + '/metadata.yaml', 'w') as f:
            yaml.dump({
                'backup': {
                    'date': datetime.now().isoformat(),
                    'source': {'nodes-names': nodes_names or []},
                },
                'couchcopy': {'version': __version__},
            }, f, default_flow_style=False)

        # rsync is used to copy from remote machine, but also to save files
        # in a specific order, and to avoid using `tar` on files that are
        # used in parallel by a running CouchDB.
        print('[rsync...]')
        # An empty hostname must not produce a ":path" rsync source, which
        # rsync would interpret as a remote host with an empty name.
        local = not hostname or hostname == 'localhost'
        source = ('' if local else hostname + ':') + path
        await subprocess('rsync', '--del', '--ignore-missing-args',
                         '-av', source + '/.shards', tmp_path + '/data')
        await subprocess('rsync', '--del', '-av', source + '/_dbs.couch',
                         tmp_path + '/data')
        await subprocess('rsync', '--del', '--ignore-missing-args',
                         '-av', source + '/shards', tmp_path + '/data')

        print('[tar...]')
        await subprocess('tar', '-I', 'pigz', '-cf', output, '-C', tmp_path,
                         'metadata.yaml', 'data')
        # For debugging purposes:
        #     shutil.rmtree('/tmp/couchcopy', True)
        #     shutil.copytree(tmp_path, '/tmp/couchcopy')
    finally:
        if scratch is not None:
            scratch.cleanup()


async def couch_conn(url, user, password):
    """Return an `aiocouch.CouchDB` connection once the server answers.

    The server is probed with `conn.info()` up to 6 times, waiting 0.25 s,
    0.5 s, 1 s, 2 s and 4 s between consecutive attempts, which leaves time
    for a freshly started CouchDB to come up.

    Raises:
        Exception: if every attempt failed; the connection is closed first
            and the last underlying error is attached as the cause.
    """
    attempts = 6
    conn = aiocouch.CouchDB(url, user=user, password=password)
    error = None
    for i in range(attempts):
        try:
            await conn.info()
            return conn
        except Exception as e:
            error = e
        # No point in sleeping after the last attempt: nothing is retried.
        if i < attempts - 1:
            await asyncio.sleep(0.25 * 2**i)
    await conn.close()
    raise Exception(
        'Cannot connect to CouchDB server: ' + repr(error)) from error


async def do_in_parallel(func, generator, url, user, password):
    """Call `await func(couch, item)` for every item of `generator`, spread
    over N_WORKERS parallel workers.

    Args:
        func: coroutine function taking `(couch, item)`.
        generator: asynchronous iterable producing the items.
        url, user, password: used by each worker to open its own CouchDB
            connection with `couch_conn`.

    Raises:
        Any exception raised by a worker (connection failure or `func`
        error), or by `generator` itself.
    """
    async def work_producer(generator, queue):
        block = []
        # Give the consumers block of 100 databases:
        async for item in generator:
            block.append(item)
            if len(block) >= 100:
                await queue.put(block)
                block = []
        if block:
            await queue.put(block)

    async def work_consumer(func, queue):
        async with await couch_conn(url, user, password) as couch:
            while True:
                try:
                    block = await queue.get()
                except asyncio.exceptions.CancelledError:
                    # Cancelled while idle: this is the normal way to stop.
                    break
                for item in block:
                    await func(couch, item)
                queue.task_done()

    queue = asyncio.Queue(maxsize=N_WORKERS)
    # Launch a work "producer" and N parallel "consumer" workers:
    producer = asyncio.create_task(work_producer(generator, queue))
    consumers = [asyncio.create_task(work_consumer(func, queue))
                 for i in range(N_WORKERS)]
    drained = None
    try:
        # Wait for either:
        # 1. The producer to have output all databases.
        # 2. Any consumer to return (can happen before 1 in case of an
        #    exception).
        await asyncio.wait([producer, *consumers],
                           return_when=asyncio.FIRST_COMPLETED)
        if producer.done():
            # Wait for either:
            # 1. The queue to be fully consumed (consumers to have finished
            #    processing the last items).
            # 2. Any consumer to return (can happen before 1 in case of an
            #    exception).
            drained = asyncio.create_task(queue.join())
            await asyncio.wait([drained, *consumers],
                               return_when=asyncio.FIRST_COMPLETED)
    finally:
        # Stop everything that is still running so that no task is left
        # pending (cancelling an already finished task is a no-op).
        for task in (producer, drained, *consumers):
            if task is not None:
                task.cancel()
    # Exceptions would be raised now:
    await asyncio.gather(*consumers)
    # Surface a failure of the producer (e.g. an error while listing the
    # databases) instead of silently reporting success.
    await producer


async def aio_all_dbs(couch):
    """Asynchronously yield every database name returned by `couch.keys()`.

    Names are requested page by page; each new page starts from a key built
    from the last name of the previous page. Iteration stops at the first
    empty page.
    """
    cursor = None
    while True:
        # Be reasonable and limit to 1000 results per call.
        page = await couch.keys(start_key=cursor, limit=1000)
        if not page:
            return
        for name in page:
            yield name

        # Use \u0020 (space) because \u0000 is not accepted by CouchDB UCA
        # (Unicode Collation Algorithm) sorter.
        cursor = f'"{name}\u0020"'


async def change_nodes_names(user, password, url, names):
    """Rewrite the shard map of every database so that all shards are owned
    by the nodes listed in `names`.

    The first database is inspected first: nothing is done when no database
    is listed or when its shards already reference exactly `names`.

    Args:
        user, password: CouchDB admin credentials.
        url: base URL of the CouchDB node to update.
        names: node names that must own the shards.

    Raises:
        AssertionError: if the number of shard ranges found in the data
            differs from the `q` value configured on the CouchDB node.
    """
    # Do a few checks on data from the first database.
    async with await couch_conn(url, user, password) as couch:
        dbs = await couch.keys(limit=1)
        if not dbs:
            print('No databases listed inside CouchDB')
            return
        _, data = await couch._server._get(f'/_node/_local/_dbs/{dbs[0]}')
        current_names = data['by_node'].keys()
        shard_ranges = sorted(data['by_range'])
        q_from_archive = len(shard_ranges)
        _, data = await couch._server._get('/_node/_local/_config')
        q_from_cluster = int(data['cluster']['q'])
        # Raised explicitly (rather than with `assert`) so that the check is
        # still performed when Python runs with -O.
        if q_from_cluster != q_from_archive:
            raise AssertionError(
                f'Error q from CouchDB ({q_from_cluster}) != q from archive '
                f'({q_from_archive}), you need to change the `q` value used '
                f'by CouchDB. For more infos see the README.')
        if sorted(names) == sorted(current_names):
            print('CouchDB nodes names already good in shards')
            return

    # Update cluster metadata on CouchDB first node.
    # Metadata are automatically transferred to the other nodes by CouchDB.
    # To understand what below code do, have a look at
    # GET /_node/_local/_dbs/{db} endpoint:
    # https://docs.couchdb.org/en/3.1.1/cluster/sharding.html#updating-cluster-metadata-to-reflect-the-new-target-shard-s
    print('[Updating CouchDB metadata...]')

    async def update_one_db_metadata(couch, db):
        # Every shard range is declared as hosted by every node in `names`.
        _, data = await couch._server._get(f'/_node/_local/_dbs/{db}')
        data['changelog'] = [['add', shard_range, name]
                             for shard_range in shard_ranges
                             for name in names]
        data['by_node'] = {name: shard_ranges for name in names}
        data['by_range'] = {shard_range: names for shard_range in shard_ranges}
        await couch._server._put(f'/_node/_local/_dbs/{db}', data=data)

    # This connection is only used to list the databases; the workers of
    # `do_in_parallel` open their own.
    async with await couch_conn(url, user, password) as couch:
        await do_in_parallel(update_one_db_metadata, aio_all_dbs(couch),
                             url, user, password)


async def load(archive, admin=None, tmp_dir=None, blocking=True):
    """Spawn a local CouchDB instance serving the data of `archive`.

    The archive is extracted into a temporary directory, a CouchDB
    configuration listening on two free local ports is generated there,
    CouchDB is started and its shard maps are rewritten to the local node
    name.

    Args:
        archive: path of a backup archive.
        admin: 'ADMIN:PASSWORD' for the spawned instance (the password may
            contain ':'); defaults to 'admin:password'.
        tmp_dir: optional parent directory for the temporary directory.
        blocking: when true, print the instance URL and wait until CouchDB
            exits; when false, terminate CouchDB once the shard maps have
            been updated.

    Returns:
        `(node_name, tmp_dir)`: the node name used and the
        `TemporaryDirectory` object holding the data.

    Raises:
        ValueError: if `admin` is not of the form 'ADMIN:PASSWORD'.
    """
    # Split on the first ':' only, so that passwords may contain colons.
    creds = ['admin', 'password'] if not admin else admin.split(':', 1)
    if len(creds) != 2:
        raise ValueError('wrong admin value, expected ADMIN:PASSWORD')

    tmp_dir = TemporaryDirectory(prefix='couchcopy-', dir=tmp_dir)
    os.mkdir(tmp_dir.name + '/etc')
    os.mkdir(tmp_dir.name + '/etc/local.d')
    os.mkdir(tmp_dir.name + '/data')

    # Ask the OS for two free ports. Both sockets are kept open together so
    # that the two ports are guaranteed to differ.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s1, \
            socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s2:
        s1.bind(('localhost', 0))
        _, port1 = s1.getsockname()
        s2.bind(('localhost', 0))
        _, port2 = s2.getsockname()

    for file in ('vm.args', 'default.ini', 'local.ini'):
        shutil.copy('/etc/couchdb/' + file, tmp_dir.name + '/etc/' + file)

    with open(tmp_dir.name + '/etc/local.d/couchcopy.ini', 'w') as f:
        f.write(f'[chttpd]\n'
                f'port = {port1}\n'
                f'\n'
                f'[httpd]\n'
                f'port = {port2}\n'
                f'\n'
                f'[couchdb]\n'
                f'database_dir = {tmp_dir.name}/data\n'
                f'view_index_dir = {tmp_dir.name}/data\n'
                f'max_dbs_open = 10000\n'
                f'users_db_security_editable = true\n'  # for CouchDB 3+
                f'\n'
                f'[cluster]\n'
                f'q=2\n'  # Change this to match your origin cluster `q`
                f'n=1\n'
                f'\n'
                f'[smoosh]\n'
                f'db_channels = ,\n'
                f'view_channels = ,\n'
                f'\n'
                f'[admins]\n'
                f'{creds[0]} = {creds[1]}\n')

    print('[untar...]')
    await subprocess('tar', '-I', 'pigz', '-xf', archive, '-C', tmp_dir.name)

    with open(tmp_dir.name + '/metadata.yaml') as f:
        conf = yaml.safe_load(f)
        nodes_names = conf['backup']['source']['nodes-names']

    # Reuse the first node name recorded in the archive when there is one,
    # otherwise generate a random local node name.
    node_name = nodes_names[0] if nodes_names else \
        'couchcopy-%s@localhost' % ''.join(random.choice(
            string.ascii_letters + string.digits) for _ in range(10))

    # Rewrite the `-name` line of vm.args in place (with `inplace=True`,
    # `print` writes into the file).
    with fileinput.input(tmp_dir.name + '/etc/vm.args',
                         inplace=True) as lines:
        for line in lines:
            if re.match(r'^-name \S+$', line):
                print('-name ' + node_name)
            else:
                print(line, end='')

    env = dict(os.environ,
               COUCHDB_ARGS_FILE=tmp_dir.name + '/etc/vm.args',
               COUCHDB_INI_FILES=(tmp_dir.name + '/etc/default.ini ' +
                                  tmp_dir.name + '/etc/local.ini ' +
                                  tmp_dir.name + '/etc/local.d'))
    with open(tmp_dir.name + '/log', 'w') as log:
        process = await asyncio.create_subprocess_exec(
            'couchdb', env=env, stdout=log, stderr=log)
        try:
            await change_nodes_names(creds[0], creds[1],
                                     f'http://localhost:{port1}', [node_name])

            if blocking:
                print(f'Launched CouchDB instance at '
                      f'http://{":".join(creds)}@localhost:{port1}')
            else:
                process.terminate()
            await process.wait()
        finally:
            # Do not leave an orphan CouchDB behind when something failed
            # (or was cancelled) before the process exited.
            if process.returncode is None:
                try:
                    process.terminate()
                except ProcessLookupError:
                    pass

    # For debugging purposes:
    # if tmp_dir:
    #     shutil.rmtree('/tmp/couchcopy', True)
    #     shutil.copytree(tmp_dir.name, '/tmp/couchcopy')

    return node_name, tmp_dir


async def unbrand(old_archive, new_archive, tmp_dir=None):
    """Write `new_archive` from `old_archive` with shards attached to a
    reusable node name.

    `old_archive` is loaded into a temporary local CouchDB (which rewrites
    the shard maps and is then stopped), and the resulting data directory is
    backed up again together with the node name that was used.

    Args:
        old_archive: path of the archive to read.
        new_archive: path of the archive to write.
        tmp_dir: optional parent directory for all temporary data.
    """
    node_name, loaded_dir = await load(old_archive, tmp_dir=tmp_dir,
                                       blocking=False)
    try:
        # Forward `tmp_dir` so that the second working copy also lands in
        # the directory chosen by the caller.
        await backup('localhost', loaded_dir.name + '/data', new_archive,
                     tmp_dir=tmp_dir, nodes_names=[node_name])
    finally:
        loaded_dir.cleanup()


async def restore(archive, cred, hostnames, paths, ports, names, couchdb_start,
                  couchdb_stop, force, use_sudo):
    """Restore `archive` onto one or several CouchDB nodes.

    Steps: start the first node and check its membership and `n` value, ask
    for confirmation, stop every node and delete its data files, extract the
    archive on every node, restart the first node and rewrite the shard
    maps, restart the other nodes and wait until they report as many
    databases as the first one.

    Args:
        archive: path of the backup archive.
        cred: 'user:password' CouchDB admin credentials.
        hostnames, paths, ports, names: per-node hostname, data directory,
            HTTP port and CouchDB node name (parallel lists). If the first
            hostname is not 'localhost', commands are run through ssh.
        couchdb_start, couchdb_stop: command lines starting / stopping
            CouchDB.
        force: skip the interactive confirmation.
        use_sudo: prefix the commands with `sudo`.

    Raises:
        AssertionError: if the cluster node names differ from `names`, or if
            the configured `n` is lower than the number of nodes.
        SystemExit: if the user does not confirm the deletion.
    """
    user, password = cred.split(':')
    remote = hostnames[0] != 'localhost'
    urls = [f'http://{hostname if remote else "localhost"}:{port}'
            for hostname, port in zip(hostnames, ports)]
    sudo = ['sudo'] if use_sudo else []
    couchdb_start = 'sudo ' + couchdb_start if use_sudo else couchdb_start
    couchdb_stop = 'sudo ' + couchdb_stop if use_sudo else couchdb_stop

    print('[Checking CouchDB nodes names and n...]')
    if remote:
        await subprocess('ssh', hostnames[0], couchdb_start)
    else:
        await subprocess(couchdb_start, shell=True)

    # These checks are raised explicitly (rather than with `assert`) so that
    # they are still performed when Python runs with -O.
    async with await couch_conn(urls[0], user, password) as couch:
        current_names = sorted((
            await couch._server._get('/_membership'))[1]['cluster_nodes'])
        if current_names != sorted(names):
            raise AssertionError(
                f'Error in names: {current_names} != {sorted(names)}. Try to '
                f'change nodes names with [nodename] arguments.')
        _, data = await couch._server._get('/_node/_local/_config')
        if int(data['cluster']['n']) < len(names):
            raise AssertionError(
                'Error n < nodes count, this is not supported, for more '
                'infos see the README.')

    # Stop CouchDB and delete existing data.
    if not force:
        dirs = ' & '.join([hostname + ':' + path if hostname else path
                           for hostname, path in zip(hostnames, paths)])
        print(f'This command will wipe-out directories {dirs}, '
              f'is it OK? [y/N]')
        answer = input()
        if answer not in ('Y', 'y'):
            print('Operation aborted.')
            sys.exit(1)
    print('[rm...]')
    for hostname, path in zip(hostnames, paths):
        if remote:
            await subprocess('ssh', hostname, couchdb_stop)
            await subprocess(
                'ssh', hostname, *sudo, 'rm', '-rf', path + '/_dbs.couch',
                path + '/_users.couch', path + '/.delete',
                path + '/._users_design', path + '/.shards', path + '/shards')
        else:
            await subprocess(couchdb_stop, shell=True)
            await subprocess(
                *sudo, 'rm', '-rf', path + '/_dbs.couch',
                path + '/_users.couch', path + '/.delete',
                path + '/._users_design', path + '/.shards', path + '/shards')

    if remote:
        # There is a strange issue: if a majority of CouchDB nodes don't
        # have the `shards` directory on startup, `_security` of databases are
        # reset to their default values.
        # This issue is possibly related to:
        # https://github.com/apache/couchdb/issues/1611
        # A workaround is to copy `shards` directory to all nodes (instead of
        # just one), it's what is done here.
        async def rsync(hostname):
            await subprocess(
                'rsync', '-av', archive, hostname + ':/tmp/couchcopy.tar.gz')

        async def untar(hostname, path, first):
            # Untar `_dbs.couch` only for the first node. The node position
            # is used (not the hostname value), so this stays correct even
            # if two nodes share the same hostname.
            exclude = [] if first else ['--exclude=_dbs.couch']
            await subprocess('ssh', hostname, *sudo, 'tar', '-I', 'pigz',
                             '--strip-components=1', '-xf',
                             '/tmp/couchcopy.tar.gz', '-C', path, *exclude)
            await subprocess('ssh', hostname, *sudo, 'chown', '-R',
                             'couchdb:couchdb', path)
            # `tar` overwrite permissions, restore them.
            if use_sudo:
                await subprocess('ssh', hostname, *sudo, 'chmod', '+rx', path)

        print('[rsync...]')
        await asyncio.gather(*(rsync(hostname) for hostname in hostnames))
        print('[untar...]')
        await asyncio.gather(*(
            untar(hostname, path, index == 0)
            for index, (hostname, path) in enumerate(zip(hostnames, paths))))
    else:
        print('[untar...]')
        await subprocess(*sudo, 'tar', '-I', 'pigz', '--strip-components=1',
                         '-xf', archive, '-C', paths[0])
        if use_sudo:
            await subprocess(*sudo, 'chown', '-R', 'couchdb:couchdb', paths[0])
            # `tar` overwrite permissions, restore them.
            await subprocess(*sudo, 'chmod', '+rx', paths[0])

    # Start first CouchDB node.
    if remote:
        await subprocess('ssh', hostnames[0], couchdb_start)
    else:
        await subprocess(couchdb_start, shell=True)

    await change_nodes_names(user, password, urls[0], names)
    print(f'CouchDB node {names[0]} is now restored, you can use it!')

    # Start other CouchDB nodes.
    for hostname in hostnames[1:]:
        if remote:
            await subprocess('ssh', hostname, couchdb_start)
        else:
            await subprocess(couchdb_start, shell=True)

    # Wait for the same number of databases on each nodes.
    print('[Waiting for CouchDB cluster synchronization...]')
    async with await couch_conn(urls[0], user, password) as couch:
        _, data = await couch._server._get('/_dbs')
        node1_dbs_count = data['doc_count'] + data['doc_del_count']
    couchs = []
    try:
        # Connections are opened inside the `try` so that the ones already
        # open are closed if a later node cannot be reached.
        for url in urls[1:]:
            couchs.append(await couch_conn(url, user, password))
        dones = set()
        while len(dones) < len(couchs):
            for couch, name in zip(couchs, names[1:]):
                if name in dones:
                    continue
                _, data = await couch._server._get('/_dbs')
                dbs_count = data['doc_count'] + data['doc_del_count']
                # Nothing to synchronize when the first node reports zero
                # databases: avoid dividing by zero.
                ratio = (dbs_count / node1_dbs_count
                         if node1_dbs_count else 1)
                print(f'{name} sync {ratio * 100:.0f} %')
                if dbs_count >= node1_dbs_count:
                    dones.add(name)
            await asyncio.sleep(2)
    finally:
        for couch in couchs:
            await couch.close()


async def main():
    """Parse the command line and run the requested couchcopy action.

    Invalid arguments are reported through argparse (usage message and exit
    status 2) rather than with a traceback.
    """
    examples = (
        'examples:\n'
        '  couchcopy backup old-server,/var/lib/couchdb backup.tar.gz\n'
        '  couchcopy load backup.tar.gz\n'
        '  couchcopy unbrand slow_backup.tar.gz quick_backup.tar.gz\n'
        '  couchcopy restore backup.tar.gz \\\n'
        '      admin:password@cluster_vm1,/var/lib/couchdb \\\n'
        '      admin:password@cluster_vm2,/var/lib/couchdb \\\n'
        '      admin:password@cluster_vm3,/var/lib/couchdb\n')

    parser = argparse.ArgumentParser(
            prog='couchcopy', epilog=examples,
            # needed for examples
            formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument('-v', '--version', action='store_true')
    subparsers = parser.add_subparsers(dest='action')
    sub = {}
    sub['backup'] = subparsers.add_parser('backup', description=(
        'Backup a CouchDB cluster from one of its node.'))
    sub['backup'].add_argument('couchdb',
                               metavar='hostname,/couchdb/data/dir')
    sub['backup'].add_argument('archive', metavar='backup.tar.gz')
    sub['backup'].add_argument('--rsync-reuse-dir', help=(
        'directory on the local machine to store data reused between '
        'executions, if not set rsync starts from scratch'))
    sub['backup'].add_argument('--tmp-dir', help=(
        'directory on the local machine to store temporary data'))
    sub['unbrand'] = subparsers.add_parser('unbrand', description=(
        'Unbrand shards inside a backup.tar.gz file from their origin CouchDB '
        'node name to a unique and reusable node name. Use this option to '
        'improve the local load speed of the archive.'))
    sub['unbrand'].add_argument('old_archive', metavar='backup.tar.gz')
    sub['unbrand'].add_argument('new_archive', metavar='new_backup.tar.gz')
    sub['unbrand'].add_argument('--tmp-dir', help=(
        'directory on the local machine to store temporary data'))
    sub['load'] = subparsers.add_parser('load', description=(
        'Spawn a local CouchDB instance and load data into it.'))
    sub['load'].add_argument('archive', metavar='backup.tar.gz')
    sub['load'].add_argument('--admin', metavar='ADMIN:PASSWORD', help=(
        'Set the CouchDB cluster admin user.'))
    sub['load'].add_argument('--tmp-dir', help=(
        'directory on the local machine to store temporary data'))
    sub['restore'] = subparsers.add_parser('restore', description=(
        'Restore a full cluster from a backup.'))
    sub['restore'].add_argument('archive', metavar='backup.tar.gz')
    meta = '[admin:password@]hostname[:5984],/couchdb/data/dir[,nodename]'
    sub['restore'].add_argument(
        'couchdbs', metavar=meta, nargs='+', help=(
            'data needed to connect and overwrite existing CouchDB node. If '
            'hostname is not "localhost", ssh will be used to connect to the '
            'remote machine. nodename default is "couchdb@127.0.0.1" if '
            'hostname is "localhost", otherwise it is "couchdb@<hostname>"'))
    sub['restore'].add_argument('--couchdb-start',
                                default='systemctl start couchdb',
                                help='command-line used to start CouchDB')
    sub['restore'].add_argument('--couchdb-stop',
                                default='systemctl stop couchdb',
                                help='command-line used to stop CouchDB')
    sub['restore'].add_argument('-y', action='store_true', help=(
        'delete existing CouchDB files without asking for confirmation'))
    sub['restore'].add_argument('--use-sudo', action='store_true')

    args = parser.parse_args()
    if args.version:
        print(__version__)
        return
    elif not args.action:
        parser.error('no action given')
    elif args.action == 'backup':
        splitted = args.couchdb.split(',')
        if len(splitted) != 2:
            sub['backup'].error('wrong "hostname,/couchdb/data/dir"')
        hostname, path = splitted
        if hostname and any(c in hostname for c in (':', '@')):
            sub['backup'].error('wrong "hostname,/couchdb/data/dir"')
        # Reported through argparse (not `assert`, which is stripped by
        # `python -O`).
        if args.rsync_reuse_dir and args.tmp_dir:
            sub['backup'].error(
                'cannot use --tmp-dir and --rsync-reuse-dir in the same time')

        await backup(hostname, path, args.archive, args.rsync_reuse_dir,
                     args.tmp_dir)
    elif args.action == 'unbrand':
        await unbrand(args.old_archive, args.new_archive, args.tmp_dir)
    elif args.action == 'load':
        await load(args.archive, args.admin, args.tmp_dir)
    elif args.action == 'restore':
        creds = []
        hostnames = []
        ports = []
        paths = []
        names = []
        for couchdb in args.couchdbs:
            # "<access>,<path>[,<nodename>]"
            access, sep, couchdb_internal = couchdb.partition(',')
            if not sep:
                sub['restore'].error(f'wrong {meta} {couchdb}')
            # Split on the last '@' so that a password may contain '@'.
            cred, sep, hostname_and_port = access.rpartition('@')
            if not sep:
                cred = 'admin:password'
            if len(cred.split(':')) != 2:
                sub['restore'].error('wrong credentials')
            hostname, sep, port = hostname_and_port.partition(':')
            if not sep:
                port = '5984'
            elif not port.isdigit():
                sub['restore'].error(f'wrong port {port}')
            default_name = 'couchdb@' + (
                '127.0.0.1' if hostname == 'localhost' else hostname)
            path, sep, name = couchdb_internal.partition(',')
            if not sep:
                name = default_name
            if len(name.split('@')) != 2 or ',' in name:
                sub['restore'].error(f'wrong nodename {name}')
            creds.append(cred)
            hostnames.append(hostname)
            ports.append(port)
            paths.append(path)
            names.append(name)

        # Only the credentials given for the first node are used.
        await restore(args.archive, creds[0], hostnames, paths, ports, names,
                      args.couchdb_start, args.couchdb_stop, args.y,
                      args.use_sudo)

    print('Done!')


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    asyncio.run(main())
