#!/usr/bin/env python3

# SPDX-License-Identifier: GPL-3.0-or-later
# Copyright (C) 2019 Michał Góral.

import argparse
import asyncio
import collections
import os
import struct
import subprocess
import sys
import tempfile
import urllib.parse
from contextlib import AsyncExitStack, asynccontextmanager
from dataclasses import dataclass, field

# Shorthands for "give me a pipe": the asyncio one is used for the menu that
# is fed incrementally, the plain one for blocking subprocess.run() calls.
ASYNC_PIPE, PIPE = asyncio.subprocess.PIPE, subprocess.PIPE


class Command:
    """A menu action parsed from a ``-c`` option string.

    The option has the form ``"DisplayName::Command::Flags"``; the
    ``::Flags`` part is optional and defaults to an empty string.

    Raises:
        ValueError: if DisplayName or Command is empty.
    """

    def __init__(self, optstr):
        self.display, _, remainder = optstr.partition("::")
        self._command, _, self.flags = remainder.partition("::")

        # This validates user input, so raise rather than assert: assertions
        # are stripped when Python runs with -O.
        if not self.display or not self._command:
            raise ValueError("Incorrect command format: {}".format(optstr))

    @property
    def command(self):
        """The command as a ``str.format`` template for the selected entry."""
        return reformat_cmd(self._command)


@dataclass
class Socket:
    """Length-prefixed text messages over an asyncio reader/writer pair.

    Every message is UTF-8 encoded and preceded by its byte length packed as
    a 4-byte big-endian signed integer (``struct`` format ``"!i"``).
    """

    reader: asyncio.StreamReader
    writer: asyncio.StreamWriter
    # One lock per socket. A plain ``asyncio.Lock()`` default is evaluated
    # once at class definition and would be shared by every instance (and on
    # older Python versions bound to whatever loop existed at import time).
    _lock: asyncio.Lock = field(default_factory=asyncio.Lock)

    async def send(self, msg):
        """Encode ``msg`` and send it as a single framed message."""
        assert (
            not self.writer.is_closing()
        ), "can't send new data while closing the socket"

        data = msg.encode()
        msglen = struct.pack("!i", len(data))
        async with self._lock:
            self.writer.writelines([msglen, data])
            await self.writer.drain()

    async def recv(self):
        """Receive one framed message and return it decoded.

        Raises:
            asyncio.IncompleteReadError: (a subclass of EOFError) if the peer
                closes the stream before a complete message has arrived.
        """
        async with self._lock:
            # readexactly() waits for the whole frame. read(n) may return
            # fewer bytes than asked for, or b"" at EOF, which would make
            # struct.unpack fail instead of signalling end of stream.
            header = await self.reader.readexactly(4)
            size = struct.unpack("!i", header)[0]
            data = await self.reader.readexactly(size)
            return data.decode()

    async def close(self):
        """Close the connection (the reader has no close() method)."""
        async with self._lock:
            self.writer.close()
            try:
                await self.writer.wait_closed()
            except ConnectionError:
                # Best effort: wait_closed() can re-raise the error that ended
                # the connection (e.g. a server reset), which would otherwise
                # escape from cleanup code.
                pass


def eprint(*a, **kw):
    """Behave like ``print()`` but always write to standard error.

    A caller-supplied ``file`` keyword is overridden.
    """
    print(*a, **dict(kw, file=sys.stderr))


def default_sock_path():
    """Return ``<tempdir>/kpsh-<uid>.sock``, the default socket location."""
    return os.path.join(tempfile.gettempdir(), f"kpsh-{os.getuid()}.sock")


@asynccontextmanager
async def connect(path):
    """Yield a Socket connected to the Unix socket at ``path``.

    The socket is closed when the ``async with`` body exits, including when
    it exits with an exception.
    """
    streams = await asyncio.open_unix_connection(path)
    conn = Socket(*streams)
    try:
        yield conn
    finally:
        await conn.close()


async def read(proc):
    """Collect lines from ``proc.stdout`` with trailing newlines removed.

    Reading stops at end of stream or at the first empty line, whichever
    comes first; the empty line itself is not included.
    """
    collected = []
    stream = proc.stdout
    while True:
        text = (await stream.readline()).decode().rstrip("\n")
        if not text:
            return collected
        collected.append(text)
        if stream.at_eof():
            return collected


async def write(proc, input_, inclose=True):
    """Write newline-terminated lines to ``proc``'s stdin.

    Args:
        proc: an asyncio subprocess started with a piped stdin.
        input_: a single string, an iterable of strings, or None (nothing is
            written, but stdin is still closed when ``inclose`` is set).
        inclose: close stdin afterwards so the process sees end of input.

    Returns:
        True if everything was written; False if stdin was already closing
        or the process went away while we were writing.
    """
    if input_ is None:
        input_ = []
    elif isinstance(input_, str):
        input_ = [input_]

    stdin = proc.stdin
    if stdin.is_closing():
        return False

    try:
        for line in input_:
            stdin.write("{}\n".format(line).encode())
        await stdin.drain()

        if inclose:
            stdin.close()
            await stdin.wait_closed()
    except (BrokenPipeError, ConnectionResetError):
        # The process may exit (e.g. the menu is dismissed) before it has
        # read everything; report that the same way as an already-closed
        # stdin instead of crashing.
        return False

    return True


def reformat_cmd(cmd):
    """Turn a user command with ``{}`` placeholders into a format template.

    Literal braces are doubled so ``str.format`` keeps them. Every ``{}``
    placeholder becomes ``{0}``, so formatting with the single selected
    entry fills all of them; with bare ``{}`` fields a second placeholder
    would make ``format(entry)`` raise IndexError.

    >>> reformat_cmd("autotype {}").format("a/b")
    'autotype a/b'
    """
    return "{0}".join(
        part.replace("{", "{{").replace("}", "}}") for part in cmd.split("{}")
    )


def menu(cmd, input_):
    """Run the shell command ``cmd`` with ``input_`` on stdin.

    Returns:
        The whitespace-stripped stdout, or None if the command exited with a
        non-zero status.
    """
    result = subprocess.run(cmd, shell=True, text=True, input=input_, stdout=PIPE)
    return result.stdout.strip() if result.returncode == 0 else None


async def menu_async(cmd):
    """Start the shell command ``cmd`` with piped stdin and stdout.

    Returns the asyncio Process immediately, without waiting for it to exit.
    """
    process = await asyncio.create_subprocess_shell(
        cmd, stdout=ASYNC_PIPE, stdin=ASYNC_PIPE
    )
    return process


def prompt(input_, pinentry):
    """Ask the user for input through pinentry (Assuan SETDESC + GETPIN).

    Args:
        input_: description shown in the pinentry dialog.
        pinentry: path of the pinentry program to run.

    Returns:
        The entered text, or None if pinentry exited with an error, printed
        nothing, or returned no data lines (e.g. the dialog was cancelled).
    """
    # Assuan arguments must percent-escape '%', CR and LF. An unescaped
    # newline would end SETDESC early and send the rest as another command.
    desc = input_.replace("%", "%25").replace("\r", "%0D").replace("\n", "%0A")
    pein = "setdesc {}\ngetpin\n".format(desc)
    cp = subprocess.run(pinentry, input=pein, capture_output=True, text=True)

    if cp.returncode != 0 or not cp.stdout:
        return None

    # The answer comes back in one or more "D " lines, percent-escaped the
    # same way; concatenate them and undo the escaping, otherwise e.g. a '%'
    # in a password would be returned as "%25".
    chunks = [line[2:] for line in cp.stdout.splitlines() if line.startswith("D ")]
    if not chunks:
        return None
    return urllib.parse.unquote("".join(chunks))


async def communicate(sock, msg, pinentry):
    """Send ``msg`` to the kpsh server and gather its replies.

    Replies are read until an ``"OK"`` line or until the connection ends.
    ``M``/``E`` replies contribute their payload to the result. ``P``/``PS``
    replies ask the user via pinentry and send the answer back. Other reply
    types are ignored.

    Returns:
        The list of collected payloads, or None if a pinentry prompt
        returned nothing.
    """
    collected = []
    await sock.send(msg)
    while True:
        try:
            resp = await sock.recv()
        except EOFError:
            return collected
        except ConnectionResetError:
            eprint(
                "Connection reset by kpsh server - probably other "
                "client currently blocks it."
            )
            return collected

        if resp == "OK":
            return collected

        kind, _, payload = resp.partition(" ")
        if kind == "M" or kind == "E":
            collected.append(payload)
        elif kind == "P" or kind == "PS":
            answer = prompt(payload, pinentry)
            if answer is None:
                return None
            await sock.send(answer)


def display_options(options, prefix):
    """Map ``prefix + option`` labels back to their options, keeping order."""
    labelled = collections.OrderedDict()
    for option in options:
        labelled[prefix + option] = option
    return labelled


def prepare_args():
    """Build the command-line parser and parse ``sys.argv``.

    Returns:
        argparse.Namespace with ``menu``, ``socket_path``, ``commands``
        (None unless -c was given), ``notify`` and ``pinentry``.
    """
    ap = argparse.ArgumentParser(
        description="Rofi/dmenu/fzf etc. access to KeePass database via kpsh. "
        "It is optimized for implementations which can read stdin "
        "in a non-blocking manner, e.g. rofi -async-pre-read=0 or "
        'dmenu with "non-blocking stdin" patch. This way menu is '
        "displayed very quickly and you can start typing "
        "immediately, while database is unlocked and listed in "
        "background."
    )
    ap.add_argument(
        "-m",
        "--menu",
        default='rofi -dmenu -async-pre-read 0 -i -p ">"',
        help="menu command to use. By default uses rofi.",
    )
    ap.add_argument(
        "-s",
        "--socket-path",
        default=default_sock_path(),
        help="Path to the socket to which kpsh-menu will connect.",
    )
    ap.add_argument(
        "-c",
        "--command",
        action="append",
        dest="commands",
        help="Set of kpsh commands from which one can be chosen "
        "and performed on a selected entry. This argument has "
        'special formatting: "DisplayName::Command::Flags". '
        "DisplayName is any human-friendly string to be "
        "displayed in menu. Command is any kpsh command and "
        'might contain a placeholder "{}" which will be '
        "replaced with selected entry path. Flags is a set of "
        "flags used to modify behavior of kpsh-menu after "
        "selected command is executed. -c can be used "
        "more than once to set up several command choices. "
        'Default: "Autotype::autotype {}". Flags: n: send '
        "notification after action; l: choose another action "
        "after this one",
    )
    ap.add_argument(
        "-n",
        "--notify",
        help="run a system command after each kpsh's command with "
        '"n" flag. Accepts {entry} and {cmd} placeholders.',
    )
    ap.add_argument(
        "--pinentry",
        default="/usr/bin/pinentry",
        help="Command used to run pinentry when kpsh server prompts for user input.",
    )

    return ap.parse_args()


def gather_commands(args):
    """Parse each ``args.commands`` string into a Command keyed by display name.

    A later command with an already-used display name replaces the earlier
    one while keeping the earlier insertion position.
    """
    parsed = [Command(optstr) for optstr in args.commands]
    return {c.display: c for c in parsed}


async def loop_states(args, sock):
    """Run the interactive menu flow until the user finishes or cancels.

    States:

    * ``"entry"``: spawn the menu, then ask kpsh for ``ls`` and stream the
      listing into it; the first selected line becomes the entry.
    * ``"actionchoice"``: pick one of the configured commands (no menu when
      only one is configured) and format it with the entry.
    * ``"action"``: run the command through kpsh, optionally run the
      --notify program (``n`` flag), show any output in the menu, then stop
      or go back to ``"actionchoice"`` (``l`` flag).

    A falsy state ends the loop.
    """
    commands = gather_commands(args)

    state = "entry"
    entry = None
    command = None
    cmd = None

    while state:
        if state == "entry":
            # communicate() only creates a coroutine here; it starts running
            # when awaited, i.e. after the menu process has been spawned.
            co_out = communicate(sock, "ls", args.pinentry)
            entry_menu = await menu_async(args.menu)
            listing = await co_out
            # communicate() returns None when a pinentry prompt gave nothing
            # back; feed the menu an empty list rather than crash on None.
            await write(entry_menu, listing if listing is not None else [])

            lines = await read(entry_menu)
            # Reap the menu process on every path, including cancellation.
            await entry_menu.wait()
            if not lines:
                break
            entry = lines[0]
            state = "actionchoice"

        if state == "actionchoice":
            if len(commands) > 1:
                choice = menu(args.menu, "\n".join(commands))
                # None/"" means the menu was cancelled; text that is not a
                # configured command (e.g. typed into the menu) is handled
                # the same way instead of raising KeyError.
                cmd = commands.get(choice)
                if cmd is None:
                    state = "entry"
                    continue
            else:
                # Read, don't pop: the "l" flag brings us back here and the
                # single command must still be available.
                cmd = next(iter(commands.values()))

            command = cmd.command.format(entry)
            state = "action"

        if state == "action":
            out = await communicate(sock, command, args.pinentry)

            state = "actionchoice" if "l" in cmd.flags else None
            if "n" in cmd.flags:
                if not args.notify:
                    eprint("No notify program. Use --notify flag to set one.")
                else:
                    # Both {entry} and {path} name the selected entry.
                    ncmd = args.notify.format(path=entry, entry=entry, cmd=command)
                    subprocess.run(ncmd, shell=True)

            if out:
                menu(args.menu, "\n".join(out))


async def main(args):
    """Connect to the kpsh daemon and run the interactive menu.

    Falls back to the single command ``"Autotype::autotype {}"`` when no -c
    option was given.

    Returns:
        1 if the daemon socket cannot be reached, otherwise None.
    """
    if not args.commands:
        args.commands = ["Autotype::autotype {}"]

    async with AsyncExitStack() as stack:
        # Guard only the connection attempt. A FileNotFoundError raised later
        # (e.g. the pinentry program does not exist) must not be reported as
        # an unreachable daemon.
        try:
            sock = await stack.enter_async_context(connect(args.socket_path))
        except (FileNotFoundError, ConnectionRefusedError):
            eprint(
                "Unable to connect to socket '{}' "
                "- is daemon running?".format(args.socket_path)
            )
            return 1
        await loop_states(args, sock)


# Run only when executed as a script, so importing this module has no side
# effects (no argument parsing, no socket connection, no sys.exit).
if __name__ == "__main__":
    sys.exit(asyncio.run(main(prepare_args())))
