#!/usr/bin/env python
'''ThingsPrompt

Shell client for ThingsDB
'''
import os
import sys
import pprint
import argparse
import asyncio
import getpass
import re
import json
import base64
import ssl
import functools
from setproctitle import setproctitle
from thingsdb.client import Client
from thingsdb.exceptions import ThingsDBError
from prompt_toolkit import __version__ as ptk_version
from prompt_toolkit.filters import Condition
from prompt_toolkit.history import FileHistory
from prompt_toolkit.history import InMemoryHistory
from prompt_toolkit.key_binding import KeyBindings
from prompt_toolkit.patch_stdout import patch_stdout
from prompt_toolkit.shortcuts import PromptSession

__version__ = '1.0.1'  # keep equal to the one in setup.py

# True when the installed prompt_toolkit reports a 3.x version.  Any other
# version takes the legacy path in this file: `session.prompt(async_=True)`
# and an explicit `use_asyncio_event_loop()` call at start-up.
PTK3 = ptk_version.startswith('3.')
# `@scope` on a line by itself: switch the default scope.  One optional
# white-space character directly after the `@` is accepted (e.g. `@ //stuff`).
USE_FUN = re.compile(r'^\s*(@\s?[\:\/0-9a-zA-Z_]+)\s*$')

# `@scope <query>`: run a single query in the given scope.  DOTALL lets the
# query part (group 2) span several lines, since new lines can be inserted in
# the prompt with CTRL + n.
SCOPE_QUERY = re.compile(r'^\s*(@[\:0-9a-zA-Z_]+)\s+(.*)$', re.DOTALL)

# A scope in `@type:name` notation (groups 1-3) or `/type/name` notation
# (groups 4-6).  The alternation is wrapped in a non-capturing group so that
# both `^` and `$` apply to each alternative; without it, trailing garbage
# after `@type:name` would be accepted.  Group numbers are unchanged.
SCOPE_RE = re.compile(
    r'^(?:(@([a-z]*):([a-zA-Z_][a-zA-Z0-9_]*))'
    r'|(/([a-z]*)/([a-zA-Z_][a-zA-Z0-9_]*)))$')


# One indentation step (four spaces); inserted by the tab key binding and
# added by the auto-indent logic after an opening bracket.
TAB = ' ' * 4

# Text printed when the user enters a single `?` at the prompt.
HELP = f'''
Version:
    {__version__}

Special commands:

?
    This help.
<@scope>
    Switch to another scope
<@scope> <query>
    Run a single query in a given scope
CTRL + n
    Insert a new line
'''

# Custom key bindings; handlers are registered below with `@bindings.add(...)`
# and the registry is handed to the prompt in prompt_loop().
bindings = KeyBindings()

# The active PromptSession; stays None until prompt_loop() creates it.
session = None


def collection_from_scope(scope: str):
    """Return the collection name contained in `scope`, or None.

    `scope` is matched against SCOPE_RE, which understands the `@type:name`
    and `/type/name` notations.  The name is only returned when the type part
    is a prefix of the word 'collection' (an empty type part counts as a
    prefix, so `//stuff` and `@:stuff` both yield 'stuff').
    """
    found = SCOPE_RE.match(scope)
    if found is None:
        return None

    # (whole, type, name) group numbers for the two supported notations.
    for whole, kind, name in ((1, 2, 3), (4, 5, 6)):
        if found.group(whole) is None:
            continue
        if 'collection'.startswith(found.group(kind)):
            return found.group(name)
        return None

    return None


class BinEncode(json.JSONEncoder):
    """JSON encoder that renders `bytes` values as base64 text."""

    def default(self, obj):
        """Return a base64 string for `bytes`; defer to the base class otherwise.

        The base class raises TypeError for objects it cannot serialize.
        """
        if not isinstance(obj, bytes):
            return super().default(obj)
        encoded = base64.b64encode(obj)
        return encoded.decode("utf-8")


@Condition
def is_active():
    """Report whether the module-level prompt session is in multiline mode.

    Wrapped in prompt_toolkit's `Condition` so it can serve as a filter.
    `session` is None until prompt_loop() has created the PromptSession, so
    this must not be evaluated before that point.
    """
    return session.multiline


def on_enter_new_line(event):
    """Insert a line break followed by automatic indentation.

    The new line starts with the same leading white-space as the last line of
    the buffer; one extra indentation step (TAB) is added when that line ends
    with an opening bracket (`{`, `(` or `[`).  Trailing white-space on the
    last line is ignored, so a line of only blanks yields no indentation.
    """
    buf = event.app.current_buffer

    # Everything after the final line break (or the whole text if there is none).
    current = buf.text.rpartition('\n')[2].rstrip()

    # Leading white-space of the current line.
    indent = current[:len(current) - len(current.lstrip())]

    if current.endswith(('{', '(', '[')):
        indent += TAB

    buf.insert_text('\n' + indent)


@bindings.add('tab')
def _(event):
    """Insert one indentation step (TAB) at the cursor position."""
    buf = event.app.current_buffer
    buf.insert_text(TAB)


@bindings.add('c-n')
def _(event):
    """CTRL + n: insert an auto-indented new line instead of submitting."""
    on_enter_new_line(event)


@bindings.add('backspace')
def _(event):
    """Delete backwards, removing four spaces at once when possible.

    When the run of spaces directly before the cursor is a non-zero multiple
    of four, four characters are removed (one indentation step); in every
    other case a single character is removed.
    """
    buf = event.app.current_buffer
    before_cursor = buf.text[:buf.cursor_position]
    trailing_spaces = len(before_cursor) - len(before_cursor.rstrip(' '))

    at_indent_step = bool(trailing_spaces) and trailing_spaces % 4 == 0
    repeat = 4 if at_indent_step else 1
    for _unused in range(repeat):
        buf.delete_before_cursor()


async def connect(client, args, auth):
    """Connect `client` to the configured node and authenticate on it.

    `args` supplies `node`, `port` and `timeout`; the same timeout is used for
    both steps.  `auth` is unpacked into the authenticate call, so it holds
    either a user name and password or a single token.
    """
    limit = args.timeout
    await client.connect(args.node, args.port, timeout=limit)
    await client.authenticate(*auth, timeout=limit)


def set_prompt(client, session):
    """Show the connection and current default scope in title and prompt.

    The text `<connection info> (<scope>)` is written to the terminal as a
    window-title escape sequence and, followed by `> `, stored as the prompt
    message of `session`.
    """
    current_scope = client.get_default_scope()
    label = f'{client.connection_info()} ({current_scope})'

    # ESC ] 0 ; <text> BEL asks the terminal to change its window title.
    print(f'\33]0;{label}\a', end='', flush=True)

    session.message = f'{label}> '


def _open_history():
    """Return the prompt history, stored on disk whenever that is possible.

    The history lives in `~/.config/ThingsPrompt/history`.  A missing file is
    created with owner-only permissions, including any missing parent
    directories.  When anything goes wrong an in-memory history is returned
    instead, so a read-only or unusual home directory never prevents the
    prompt from starting.
    """
    history_file = os.path.join(
        os.path.expanduser('~'),
        '.config',
        'ThingsPrompt',
        'history'
    )
    try:
        if not os.path.exists(history_file):
            # makedirs (unlike mkdir) also copes with a missing ~/.config.
            os.makedirs(
                os.path.dirname(history_file),
                mode=0o700,
                exist_ok=True)
            with open(history_file, 'w'):
                pass
            os.chmod(history_file, 0o600)

        return FileHistory(history_file)
    except Exception:
        # Deliberate best effort: persistent history is a convenience only.
        return InMemoryHistory()


async def prompt_loop(client, args):
    """Run the interactive prompt until the user ends the session.

    Each line read from the prompt is handled as one of:

    * `?`                -- print the help text;
    * `@scope`           -- switch the default scope of `client`;
    * `@scope <query>`   -- run the query once in the given scope;
    * anything else      -- run it as a query in the default scope.

    Query results are printed as indented JSON (bytes are base64 encoded).
    Query failures are printed and the loop continues.  The loop returns on
    EOF (CTRL + d) or a keyboard interrupt (CTRL + c).

    `args` must provide `timeout`, which is used for every query.  The created
    PromptSession is stored in the module-level `session`.
    """
    global session

    session = PromptSession(history=_open_history())
    session.client = client

    # The async prompt API differs between prompt_toolkit 3.x and older ones.
    if PTK3:
        aprompt = functools.partial(
            session.prompt_async,
            key_bindings=bindings)
    else:
        aprompt = functools.partial(
            session.prompt,
            async_=True,
            key_bindings=bindings)

    set_prompt(client, session)

    while True:
        try:
            query = await aprompt()

            if query is None:
                continue

            if query.strip() == '?':
                print(HELP)
                continue

            use = USE_FUN.match(query)
            if use:
                scope = use.group(1)
                scope = scope.strip('\'"')
                # `@ //stuff` form: drop the `@` and the separating blank.
                if scope[1] == ' ':
                    scope = scope[2:]

                client.set_default_scope(scope)
                set_prompt(client, session)
                continue

            scoped = SCOPE_QUERY.match(query)
            if scoped:
                scope, query = scoped.group(1), scoped.group(2)
            else:
                scope = None  # run in the client's default scope

            if not client.is_connected():
                print('not connected')
                continue

            try:
                res = await client.query(
                    query,
                    scope=scope,
                    timeout=args.timeout)
            except ThingsDBError as e:
                print(f'{e.__class__.__name__}: {e}')
            except (asyncio.TimeoutError, TimeoutError, OSError) as e:
                # An expired --timeout or a dropped connection must not
                # terminate the interactive shell; report it and go on.
                print(f'{e.__class__.__name__}: {e}')
            else:
                print(json.dumps(res, sort_keys=True, indent=4, cls=BinEncode))

        except (EOFError, KeyboardInterrupt):
            return


async def do_export(client, fn: str, collection: str, dump: bool):
    """Export `collection` and write the result to the file `fn`.

    `dump` is passed on to the ThingsDB `export()` call.  The query runs
    before the file is opened, so a failed export neither creates nor
    truncates `fn`; the error is printed and nothing is written.

    The result is written as binary data.  A textual result (presumably what
    `export()` returns when `dump` is false -- confirm against the ThingsDB
    documentation) is encoded as UTF-8 first, as a `str` cannot be written to
    a file opened in binary mode.
    """
    try:
        data = await client.query("""//ti
                export({dump:,});
            """, dump=dump, scope=f'//{collection}')
    except ThingsDBError as e:
        print(f'{e.__class__.__name__}: {e}')
        return

    if isinstance(data, str):
        data = data.encode('utf-8')

    with open(fn, 'wb') as f:
        f.write(data)


async def do_import(client, fn: str, collection: str, import_tasks: bool):
    """Import the export stored in file `fn` into `collection`.

    The collection is created first when it does not exist yet.
    `import_tasks` is passed on to the ThingsDB `import()` call.  The file is
    read completely and closed before any request is made.  A ThingsDBError
    from any of the requests (existence check, creation or import) is printed
    rather than raised.
    """
    with open(fn, 'rb') as f:
        data = f.read()

    try:
        if not await client.has_collection(collection):
            await client.new_collection(collection)

        await client.query("""//ti
                import(data, {import_tasks:,});
            """, data=data, import_tasks=import_tasks, scope=f'//{collection}')
    except ThingsDBError as e:
        print(f'{e.__class__.__name__}: {e}')


if __name__ == '__main__':
    # Script entry point: parse the command line, connect and authenticate,
    # then either run the `export` / `import` sub-command or start the
    # interactive prompt.  The connection is always closed afterwards.
    setproctitle('things-prompt')

    if not PTK3:
        # Older prompt_toolkit versions must be told to use asyncio.
        from prompt_toolkit.eventloop.defaults import use_asyncio_event_loop
        use_asyncio_event_loop()

    parser = argparse.ArgumentParser()

    parser.add_argument(
        '--node', '-n',
        type=str,
        default='localhost',
        help='node address')

    parser.add_argument(
        '--port',
        type=int,
        default=9200,
        help='TCP port where the node is listening on for API calls')

    parser.add_argument(
        '--user', '-u',
        type=str,
        help='user name')

    parser.add_argument(
        '--password', '-p',
        type=str,
        help='password, will be prompted if not given')

    parser.add_argument(
        '--token', '-t',
        type=str,
        help='token key')

    parser.add_argument(
        '--scope', '-s',
        type=str,
        default='@thingsdb',
        help='set the initial scope')

    parser.add_argument(
        '--timeout',
        type=int,
        help='connect and query timeout in seconds')

    parser.add_argument(
        '--ssl',
        action='store_true',
        help='enable secure connection (SSL/TLS)')

    # The `version` action prints to stdout and exits with status 0;
    # sys.exit(<str>) would print to stderr and exit with status 1.
    parser.add_argument(
        '--version',
        action='version',
        version=__version__,
        help='print version and exit')

    # `dest` records which sub-command (if any) was given as `args.command`.
    subparsers = parser.add_subparsers(
        dest='command',
        help='sub-command help')

    parser_exp = subparsers.add_parser(
        'export',
        help='export a collection')

    parser_exp.add_argument(
        'filename',
        help='filename to store the export')

    parser_exp.add_argument(
        '--structure-only',
        action='store_true',
        help=(
            'generates a textual export with only enumerators, types and '
            'procedures; '
            'without this argument the export is not readable but in '
            'MessagePack format and intended to be used for import'))

    parser_imp = subparsers.add_parser(
        'import',
        help='import a collection')

    parser_imp.add_argument(
        'filename',
        help='filename to import')

    parser_imp.add_argument(
        '--tasks',
        action='store_true',
        help='include tasks when importing a collection')

    args = parser.parse_args()

    # Validate the scope of a sub-command before asking for a password or
    # opening a connection, so a bad command line never leaves a connection
    # behind.
    collection = None
    if args.command == 'export':
        collection = collection_from_scope(args.scope)
        if collection is None:
            sys.exit(
                'not a valid collection scope; '
                'use --scope and provide a collection scope (e.g //stuff)')
    elif args.command == 'import':
        if not args.scope:
            sys.exit('the import command requires a scope (--scope)')
        collection = collection_from_scope(args.scope)
        if collection is None:
            sys.exit('not a valid collection scope')

    if args.token is None:
        if args.user is None:
            sys.exit(
                'one of the arguments -t/--token or -u/--user is required')

        if args.password is None:
            args.password = getpass.getpass('password: ')

        auth = [args.user, args.password]
    else:
        if args.user is not None:
            sys.exit('use arguments -t/--token or -u/--user, not both')
        auth = [args.token]

    # NOTE(review): a bare SSLContext(PROTOCOL_TLS) does not verify the
    # server certificate or host name.  Left unchanged because switching to
    # ssl.create_default_context() would reject self-signed nodes; confirm
    # whether verification should be enabled.
    client = Client(ssl=ssl.SSLContext(ssl.PROTOCOL_TLS) if args.ssl else None)
    client.set_default_scope(args.scope)
    loop = asyncio.get_event_loop()

    try:
        loop.run_until_complete(connect(client, args, auth))
    except Exception as e:
        print(f'{e.__class__.__name__}: {e}', file=sys.stderr)
        sys.exit(1)

    try:
        if args.command == 'export':
            dump = not args.structure_only
            fn = args.filename
            loop.run_until_complete(do_export(client, fn, collection, dump))
        elif args.command == 'import':
            fn = args.filename
            loop.run_until_complete(
                do_import(client, fn, collection, args.tasks))
        else:
            with patch_stdout():
                loop.run_until_complete(prompt_loop(client, args))
    finally:
        # Close the connection on every exit path, including errors.
        client.close()
        loop.run_until_complete(client.wait_closed())
