#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2016 Adrien Vergé
# All rights reserved

import argparse
import base64
import configparser
from datetime import datetime
import fileinput
import getpass
from io import BytesIO
import json
import multiprocessing
import os
import random
import re
import resource
import shutil
import socket
import string
import subprocess
import sys
import tarfile
import tempfile
import threading
import time
from urllib.parse import quote, unquote, urlparse, urlsplit, urlunsplit
import urllib.error
import urllib.request

import couchdb


def _check_couchdb_connection(url):
    """Complete the credentials of a CouchDB URL and check the server answers.

    Missing credentials are asked interactively (prompts go to stderr).

    :param url: server URL, with or without scheme (``http://`` is assumed
        when neither ``http://`` nor ``https://`` is given), optionally
        carrying ``user:password@`` credentials, percent-encoded as in any
        URL.
    :returns: the same URL with percent-encoded credentials embedded, e.g.
        ``http://user:pass@server/db/``.

    Exits the process with status 1 when the server cannot be reached or
    does not answer with the CouchDB welcome message.
    """
    if not url.startswith(('http://', 'https://')):
        url = 'http://' + url

    parts = list(urlsplit(url))
    server = parts[1]
    username, password = None, None
    if '@' in server:
        # rsplit: an unencoded '@' inside the password stays in credentials
        credentials, server = server.rsplit('@', 1)
        username, sep, password = credentials.partition(':')
        # Credentials in a URL are percent-encoded; decode them so that the
        # Basic auth header below and the re-quoting further down are right.
        username = unquote(username)
        password = unquote(password) if sep else None
    while not username:
        print('CouchDB admin for %s: ' % server, end='', file=sys.stderr)
        username = input()
    while not password:
        password = getpass.getpass(
            'CouchDB password for %s@%s: ' % (username, server), sys.stderr)

    parts[1] = server
    url = urlunsplit(parts)       # http://server/db/
    parts[1] = '%s:%s@%s' % (quote(username, safe=''),
                             quote(password, safe=''),
                             server)
    auth_url = urlunsplit(parts)  # http://user:pass@server/db/

    # Check connection
    auth = base64.b64encode(('%s:%s' % (username, password))
                            .encode('utf-8')).decode('utf-8')
    req = urllib.request.Request(url, headers={
        'Authorization': 'Basic %s' % auth})
    try:
        with urllib.request.urlopen(req) as response:
            data = json.loads(response.read().decode('utf-8'))
        # Explicit check rather than `assert`, which `python -O` strips.
        if not isinstance(data, dict) or data.get('couchdb') != 'Welcome':
            raise ValueError('unexpected answer: %r' % (data,))
    except Exception as e:
        print('Cannot connect to CouchDB server %s:\n%s' % (server, e),
              file=sys.stderr)
        sys.exit(1)

    return auth_url


class CouchDBInstance(object):
    """A throw-away CouchDB server running from a private temporary directory.

    The configuration is copied from ``/etc/couchdb`` into ``<tempdir>/etc``
    and patched so that the server:

    * uses the given Erlang node name (``vm.args``),
    * listens on two free local ports (chttpd and httpd),
    * stores its databases and indexes in ``<tempdir>/data``,
    * has a single admin account with a random password.

    Use it as a context manager: on exit the server is stopped if it is
    still running, and the temporary directory is removed.
    """

    def __init__(self, erlang_node, standalone_server=False):
        """Prepare the configuration (the server is not started yet).

        :param erlang_node: Erlang node name written into ``vm.args``.
        :param standalone_server: when true, ``start()`` returns as soon as
            the process is launched, without waiting for it to answer.
        """
        self.erlang_node = erlang_node
        self.tempdir = tempfile.TemporaryDirectory(prefix='coucharchive-')
        self.thread = None
        self.url = None
        self.version = None  # raw welcome JSON, set by start()
        self.standalone_server = standalone_server
        self._log = None  # server stdout/stderr file while it runs

        self._setup()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        try:
            if self.thread is not None:
                self.stop()
        finally:
            # Delete config and data now rather than whenever the
            # TemporaryDirectory object happens to be garbage-collected.
            self.tempdir.cleanup()

    @property
    def confdir(self):
        """Directory holding the instance's configuration files."""
        return self.tempdir.name + '/etc'

    @property
    def datadir(self):
        """Directory holding the instance's databases and view indexes."""
        return self.tempdir.name + '/data'

    def _random_credential(self):
        """Return ``('root', password)``, password being 10 random alnum chars."""
        # SystemRandom draws from os.urandom: this is a live admin password.
        rng = random.SystemRandom()
        alphabet = string.ascii_letters + string.digits
        return 'root', ''.join(rng.choice(alphabet) for _ in range(10))

    def _two_unused_ports(self):
        """Return two distinct local TCP ports that were free a moment ago.

        Both sockets are bound at the same time so that the OS hands out two
        different ports. They are released before CouchDB binds them, so
        another process could grab one in between.
        """
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s1, \
                socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s2:
            s1.bind(('localhost', 0))
            s2.bind(('localhost', 0))
            return s1.getsockname()[1], s2.getsockname()[1]

    def _try_to_increase_rlimit_nofile(self):
        """Raise the soft open-files limit towards 4096, warn if still lower."""
        MAX_NOFILE = 4096

        def is_low(limit):
            return limit != resource.RLIM_INFINITY and limit < MAX_NOFILE

        # Try to increase the allowed number of open files, to avoid CouchDB
        # errors:
        soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
        hard_unlimited = hard == resource.RLIM_INFINITY
        if is_low(soft) and (hard_unlimited or soft < hard):
            print('Trying to increase the max number of open files '
                  '(currently %d)...' % soft, file=sys.stderr)
            new_soft = MAX_NOFILE if hard_unlimited else min(MAX_NOFILE, hard)
            try:
                resource.setrlimit(resource.RLIMIT_NOFILE, (new_soft, hard))
            except (ValueError, OSError) as e:
                # Best effort: the warning below tells the user what to do.
                print('Cannot change the limit: %s' % e, file=sys.stderr)
            soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
        if is_low(soft):
            print(('WARNING: Max number of open files is low (%d), it could '
                   'result in server errors. If you have errors, consider '
                   'increasing the system hard limit.') % soft,
                  file=sys.stderr)

    def _setup(self):
        """Create the directory layout and write the patched configuration."""
        self._try_to_increase_rlimit_nofile()

        os.mkdir(self.confdir)
        os.mkdir(self.confdir + '/local.d')
        os.mkdir(self.datadir)

        self.creds = self._random_credential()
        self.ports = self._two_unused_ports()

        for name in ('vm.args', 'default.ini', 'local.ini'):
            shutil.copy('/etc/couchdb/' + name, self.confdir + '/' + name)

        # inplace=True redirects stdout into the file being rewritten; the
        # `with` guarantees stdout is restored even if something fails.
        with fileinput.input(self.confdir + '/vm.args', inplace=True) as lines:
            for line in lines:
                if re.match(r'^-name \S+$', line):
                    print('-name ' + self.erlang_node)
                else:
                    print(line, end='')

        with open(self.confdir + '/local.d/coucharchive.ini', 'w') as f:
            f.write('[chttpd]\n'
                    'port = %d\n' % self.ports[0] +
                    '\n'
                    '[httpd]\n'
                    'port = %d\n' % self.ports[1] +
                    '\n'
                    '[couchdb]\n'
                    'database_dir = %s\n' % self.datadir +
                    'view_index_dir = %s\n' % self.datadir +
                    'max_dbs_open = 2000\n'
                    '\n'
                    '[cluster]\n'
                    'q=1\n'  # ideal for a small, 1-node setup
                    'n=1\n'
                    '\n'
                    '[admins]\n'
                    '%s = %s\n' % self.creds)

    def start(self):
        """Launch ``couchdb`` in a background thread.

        Sets ``self.url`` (admin credentials included). Unless
        ``standalone_server`` is set, polls the chttpd port 25 times,
        0.2 s apart, until the server answers, and sets ``self.version``.

        :raises Exception: if the process dies, answers something other than
            the welcome message, or does not answer in time.
        """
        env = dict(os.environ,
                   COUCHDB_VM_ARGS=self.confdir + '/vm.args',
                   COUCHDB_INI_FILES=(self.confdir + '/default.ini ' +
                                      self.confdir + '/local.ini ' +
                                      self.confdir + '/local.d'))
        # Kept on self so that stop() can close it.
        self._log = log = open(self.tempdir.name + '/log', 'w')

        class CouchDBRunnerThread(threading.Thread):
            def __init__(self):
                super().__init__()
                self.process = None

            def run(self):
                self.process = subprocess.Popen('couchdb', env=env,
                                                stdout=log, stderr=log)
                self.process.wait()

            def terminate(self):
                # process stays None when Popen() failed (e.g. no binary)
                if self.process is not None:
                    self.process.terminate()

        self.thread = CouchDBRunnerThread()
        self.thread.start()

        self.url = 'http://%s:%s@localhost:%d' % (self.creds + self.ports[:1])

        if self.standalone_server:
            return

        for _ in range(25):
            if not self.thread.is_alive():
                raise Exception('CouchDB process died')
            try:
                with urllib.request.urlopen('http://localhost:%d'
                                            % self.ports[0]) as response:
                    self.version = response.read().decode('utf-8')
            except (urllib.error.URLError, ConnectionError):
                # Not listening yet, or dropped the connection while booting
                time.sleep(0.2)
                continue
            if '"couchdb":"Welcome"' in self.version:
                return
            self.thread.terminate()
            raise Exception('CouchDB answered: %s' % self.version)

        self.thread.terminate()
        raise Exception('CouchDB server does not answer after 5 seconds')

    def stop(self):
        """Terminate the server process, wait for it and close its log."""
        print('Terminating local CouchDB instance', file=sys.stderr)
        self.thread.terminate()
        self.thread.join()
        self.thread = None
        if self._log is not None:
            self._log.close()
            self._log = None


def replicate_couchdb_server(source_url, target_url, ignore_dbs=None):
    """Replicate every database of one CouchDB server to another.

    Databases are replicated in parallel by a pool of 16 worker processes
    running ``replicate_one_database``. The system databases
    ``_global_changes``, ``_metadata`` and ``_replicator`` are always
    skipped.

    :param source_url: URL of the source server (credentials included).
    :param target_url: URL of the target server (credentials included).
    :param ignore_dbs: optional iterable of extra database names to skip;
        it is never modified.
    :raises Exception: whatever the first failing replication raised; the
        remaining workers are terminated.
    """
    source_url = source_url.rstrip('/')
    target_url = target_url.rstrip('/')

    # Build a fresh set: extending `ignore_dbs` in place would mutate the
    # caller's list (or a shared default) on every call.
    ignored = set(ignore_dbs or ())
    ignored.update(('_global_changes', '_metadata', '_replicator'))
    all_dbs = [db for db in couchdb.Server(source_url) if db not in ignored]

    todos = [(source_url, target_url, db) for db in all_dbs]

    pool = multiprocessing.Pool(processes=16)
    try:  # see https://stackoverflow.com/a/25791961
        list(pool.imap_unordered(replicate_one_database, todos))
    except BaseException:
        # BaseException so that Ctrl-C also tears the workers down.
        print('A replication failed, stopping...', file=sys.stderr)
        pool.terminate()
        raise
    else:
        pool.close()
    finally:
        pool.join()


def _server_error_reason(error):
    """Return the ``reason`` field of a couchdb ServerError, or None.

    NOTE(review): assumes couchdb-python raises
    ``ServerError((status, (error, reason)))`` as the original indexing
    implied; any other shape yields None instead of a masking IndexError.
    """
    try:
        return error.args[0][1][1]
    except (IndexError, TypeError):
        return None


def replicate_one_database(args):
    """Replicate one database from the source server to the target server.

    Meant to be mapped over a multiprocessing pool (see
    ``replicate_couchdb_server``), hence the single tuple argument.

    Steps:

    1. Create the database on the target.
    2. Trigger a replication.
    3. Copy the ``_security`` object, retrying on server errors.
    4. Wait until both sides report the same number of documents.

    :param args: ``(source_url, target_url, db)`` — base server URLs without
        trailing slash (credentials may be embedded) and a database name.
    :raises couchdb.http.PreconditionFailed: the target database cannot be
        created (an already existing ``_users`` database is tolerated).
    :raises couchdb.http.ServerError: the security object still cannot be
        copied after the retries.
    :raises Exception: document counts still differ after the retries.
    """
    # Number of 1-second retries granted to EACH of the two wait loops.
    timeout = 3

    source_url, target_url, db = args

    source = couchdb.Server(source_url)
    target = couchdb.Server(target_url)

    # .hostname drops credentials, port and IPv6 brackets, and lower-cases.
    source_host = urlparse(source_url).hostname
    source_is_local = source_host in ('localhost', '127.0.0.1', '::1')

    try:
        target.create(db)
    except couchdb.http.PreconditionFailed as e:
        # An already existing _users database is reused; any other failure
        # to create the database is fatal.
        if not (e.args[0][0] == 'file_exists' and db in ('_users',)):
            print('%s: cannot create database on target' % db,
                  file=sys.stderr)
            raise

    # Run the replication on the source server when it is local, on the
    # target otherwise (presumably because a remote server cannot reach
    # "localhost" URLs — TODO confirm).
    server = source if source_is_local else target
    server.replicate(source_url + '/' + db, target_url + '/' + db)

    source_db = couchdb.Database(source_url + '/' + db)
    target_db = couchdb.Database(target_url + '/' + db)

    retries = timeout
    while True:
        try:
            target_db.security = source_db.security
            break
        except couchdb.http.ServerError as e:
            if retries == 0:
                if _server_error_reason(e) in ('no_majority', 'no_ring'):
                    print('Retry with a greater ulimit '
                          '(e.g. `ulimit -n 8192`)', file=sys.stderr)
                raise
            time.sleep(1)
            retries -= 1

    # Fresh budget: retries spent above must not shorten this wait.
    retries = timeout
    while True:
        source_len, target_len = len(source_db), len(target_db)
        if source_len == target_len:
            break
        if retries == 0:
            raise Exception(
                '%s: replicated database has %d docs, source has %d'
                % (db, target_len, source_len))
        time.sleep(1)
        retries -= 1

    print('%s: done' % db)


def _add_bytes_to_tar(tar, name, data):
    """Add a regular file *name* holding the bytes *data* to the open *tar*."""
    member = tarfile.TarInfo(name)
    member.size = len(data)  # byte length, not character length
    member.mtime = time.time()  # otherwise the entry is dated 1970
    tar.addfile(member, BytesIO(data))


def create(source, filename, ignore_dbs=None):
    """Back up every database of *source* into the gzipped tarball *filename*.

    Databases are first replicated into a temporary local CouchDB instance,
    which is then stopped. The archive then receives:

    * the instance's configuration (``etc``) and data (``data``) directories,
    * the Erlang node name (``erlang_node_name``, read back by
      ``_load_archive``),
    * a short ``info`` text with the date and CouchDB version.

    :param source: URL of the server to back up (credentials included).
    :param filename: path of the ``.tar.gz`` archive to write.
    :param ignore_dbs: optional iterable of database names to skip; it is
        never modified.
    """
    erlang_node = 'coucharchive-%s@localhost' % ''.join(
        random.choice(string.ascii_letters + string.digits) for _ in range(10))

    with CouchDBInstance(erlang_node) as local_couchdb:
        local_couchdb.start()
        print('Launched CouchDB instance at %s' % local_couchdb.url,
              file=sys.stderr)

        # Pass a copy: neither the caller's list nor a shared default
        # must be altered by the replication.
        replicate_couchdb_server(source, local_couchdb.url,
                                 ignore_dbs=list(ignore_dbs or ()))

        local_couchdb.stop()

        print('Creating backup archive at %s' % filename, file=sys.stderr)
        with tarfile.open(filename, 'w:gz') as tar:
            tar.add(local_couchdb.confdir, arcname='etc')
            tar.add(local_couchdb.datadir, arcname='data')

            _add_bytes_to_tar(tar, 'erlang_node_name',
                              erlang_node.encode('utf-8'))

            info = (
                'CouchDB backup made on %s\n' % datetime.now().isoformat() +
                'with CouchDB version %s\n' % local_couchdb.version
            ).encode('utf-8')
            _add_bytes_to_tar(tar, 'info', info)


def _load_archive(filename, callback):
    """Serve a coucharchive backup from a temporary local CouchDB instance.

    The archive is extracted into a temporary directory. Its ``data``
    directory becomes the data directory of a fresh CouchDBInstance. While
    that server runs, *callback* is called with the instance URL (admin
    credentials included). The server is stopped once *callback* returns.

    :param filename: path to an archive produced by ``create``.
    :param callback: callable taking the local server URL.
    :raises Exception: if *filename* does not exist, if the archive holds an
        invalid Erlang node name, or if it has no ``data`` directory.
    """
    if not os.path.isfile(filename):
        raise Exception('File "%s" does not exist' % filename)

    with tempfile.TemporaryDirectory(prefix='coucharchive-') as tmp:
        # The tar file is only needed for extraction: close it before the
        # (possibly very long) callback runs.
        with tarfile.open(filename) as tar:
            print('Extracting backup archive from %s' % filename,
                  file=sys.stderr)
            if hasattr(tarfile, 'data_filter'):
                # Refuses absolute paths, '..' members, device files and
                # links escaping `tmp`.
                tar.extractall(path=tmp, filter='data')
            else:
                # NOTE: no extraction filters on this Python version; only
                # load archives from trusted sources.
                tar.extractall(path=tmp)

        node_file = tmp + '/erlang_node_name'
        if os.path.isfile(node_file):
            with open(node_file, 'r') as f:
                erlang_node = f.read().strip()
        else:  # for archives made before coucharchive 1.2.1
            erlang_node = 'coucharchive@localhost'

        # The name ends up on a `-name` line of vm.args: anything beyond a
        # plain name@host (e.g. an embedded newline) would inject VM args.
        if not re.fullmatch(r'[A-Za-z0-9_.-]+@[A-Za-z0-9_.-]+', erlang_node):
            raise Exception('Invalid Erlang node name in archive "%s": %r'
                            % (filename, erlang_node))

        data_dir = tmp + '/data'
        if not os.path.isdir(data_dir):
            raise Exception('Archive "%s" has no data directory' % filename)

        with CouchDBInstance(erlang_node) as local_couchdb:
            os.rmdir(local_couchdb.datadir)
            os.rename(data_dir, local_couchdb.datadir)

            local_couchdb.start()
            print('Launched CouchDB instance at %s' % local_couchdb.url,
                  file=sys.stderr)

            callback(local_couchdb.url)


def load(filename):
    """Serve the backup *filename* from a local CouchDB until interrupted.

    The local instance keeps running until Ctrl-C is pressed (or a year has
    elapsed), then it is shut down.
    """
    one_year_in_seconds = 365 * 24 * 3600

    def wait_until_interrupted(local_couch_server_url):
        print('Ready!', file=sys.stderr)
        try:
            time.sleep(one_year_in_seconds)
        except KeyboardInterrupt:
            pass  # Ctrl-C is the normal way to end a `load` session

    _load_archive(filename, wait_until_interrupted)


def restore(target, filename, ignore_dbs=None):
    """Load the backup *filename* locally and replicate it to *target*.

    :param target: URL of the destination server (credentials included).
    :param filename: archive produced by ``create``.
    :param ignore_dbs: optional iterable of database names not to restore;
        it is never modified.
    """
    # Work on a copy: neither the caller's list nor a shared default must be
    # altered by the replication.
    ignored = list(ignore_dbs or ())

    def replicate_to_target(local_couch_server_url):
        replicate_couchdb_server(local_couch_server_url,
                                 target, ignore_dbs=ignored)

    _load_archive(filename, replicate_to_target)


def main():
    """Command-line entry point.

    Parses the arguments and fills missing server URLs from the optional
    config file. The config file may contain:

    * ``[source]`` / ``[target]`` sections with a ``url`` key,
    * a ``[replication]`` section whose ``ignore_dbs`` key is a
      comma-separated list.

    It then checks the connection to every server involved and runs the
    requested action. It exits with status 1 (after printing help) when a
    required server URL is missing or no action is given, and with status 2
    when the given config file cannot be read.
    """
    # Get action and archive file from command line
    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--config', dest='config_file',
                        action='store', help='path to config file')
    subparsers = parser.add_subparsers(dest='action')
    sub = {}

    sub['create'] = subparsers.add_parser('create')
    sub['create'].add_argument(
        '--from', dest='source_server', action='store',
        help='source CouchDB server to create archive from')
    sub['create'].add_argument(
        '-o', '--output', dest='output', action='store', required=True,
        help='path to archive to create')

    sub['restore'] = subparsers.add_parser('restore')
    sub['restore'].add_argument(
        '--to', dest='target_server', action='store',
        help='target CouchDB server to restore archive to')
    sub['restore'].add_argument(
        '-i', '--input', dest='input', action='store', required=True,
        help='path to archive to restore')

    sub['load'] = subparsers.add_parser('load')
    sub['load'].add_argument(
        '-i', '--input', dest='input', action='store', required=True,
        help='path to archive to load')

    sub['replicate'] = subparsers.add_parser('replicate')
    sub['replicate'].add_argument(
        '--from', dest='source_server', action='store',
        help='source CouchDB server to replicate from')
    sub['replicate'].add_argument(
        '--to', dest='target_server', action='store',
        help='target CouchDB server to replicate to')

    args = parser.parse_args()

    config = configparser.ConfigParser()
    # ConfigParser.read() silently skips unreadable files and returns the
    # list of files it did read; an explicitly requested config must not be
    # ignored without notice.
    if args.config_file and not config.read(args.config_file):
        parser.error('cannot read config file %s' % args.config_file)

    if args.action in ('create', 'replicate'):
        if not args.source_server and 'source' in config.sections():
            args.source_server = config['source'].get('url', '')
        if not args.source_server:
            sub[args.action].print_help()
            parser.exit(1)
    if args.action in ('restore', 'replicate'):
        if not args.target_server and 'target' in config.sections():
            args.target_server = config['target'].get('url', '')
        if not args.target_server:
            sub[args.action].print_help()
            parser.exit(1)

    ignore_dbs = []
    if 'replication' in config.sections():
        raw = config['replication'].get('ignore_dbs', '').split(',')
        ignore_dbs = [db.strip() for db in raw if db.strip()]

    # Not every subcommand defines these attributes, hence getattr().
    if getattr(args, 'source_server', None):
        args.source_server = _check_couchdb_connection(args.source_server)
    if getattr(args, 'target_server', None):
        args.target_server = _check_couchdb_connection(args.target_server)

    if args.action == 'create':
        create(args.source_server, args.output, ignore_dbs=ignore_dbs)
    elif args.action == 'restore':
        restore(args.target_server, args.input, ignore_dbs=ignore_dbs)
    elif args.action == 'load':
        load(args.input)
    elif args.action == 'replicate':
        replicate_couchdb_server(args.source_server, args.target_server,
                                 ignore_dbs=ignore_dbs)
    else:
        parser.print_help()
        parser.exit(1)


if __name__ == '__main__':
    # Only run the command-line interface when executed as a script, not
    # when imported (multiprocessing workers import this module too).
    main()
