#!/usr/bin/env python3
# cython: language_level=3

import argparse
import json
import os
import signal
import socket
import sys
import textwrap
import time
from queue import Empty, Queue
from subprocess import DEVNULL, CalledProcessError, check_call
from threading import Thread, get_native_id as pid

# Release identifier, reported by the "version" command and the help text.
version = "1.1.7"

# Help text: returned by the "help" command and used as the argparse epilog.
info = "\n".join(
    [
        "Erebor ({})".format(version),
        "    set      <table> <key> <value>  set a key in a table to a value",
        "    del      <table> <key>          remove a key and its value from a table",
        "    get      <table> <key>          retrieve the value of a key from a table",
        "    dump     <table>                return all keys and values for a table in JSON",
        "    pwd                             return the filesystem location of the tables",
        "    ls                              return a list of all table names",
        "    version                         return the current erebor version",
        "    help                            display this helpful information",
    ]
)

# Default listening address (used as the --host / --port defaults).
host = "127.0.0.1"
port = 8044

# Table storage location: the first is the --path default, the second is
# substituted for it when running with --daemon.
default_storage_directory = ".erebor"
daemon_storage_directory = "/var/lib/erebor"


def _systemctl(action):
    """Run ``systemctl <action>`` with its output discarded.

    Raises:
        CalledProcessError: systemctl exited with a non-zero status.
        OSError: the systemctl executable could not be started.
    """
    # DEVNULL avoids opening (and leaking) a handle on os.devnull per call.
    check_call(["systemctl"] + action.split(), stdout=DEVNULL, stderr=DEVNULL)


def daemon(host, port, path="/var/lib/erebor"):
    """Install erebor as a systemd service, then enable and start it.

    Writes /etc/systemd/system/erebor.service with an ExecStart line that
    runs the ``erebor`` file located next to this module, reloads systemd,
    and enables and starts the unit.

    Args:
        host: address passed to the service as ``--host``.
        port: port passed to the service as ``--port``.
        path: storage directory passed to the service as ``--path``.

    Returns:
        None. Failures do not raise; an installation failure is reported
        on stderr.
    """
    service = """[Unit]
    Description=persistent key-value store
    After=network.target
    StartLimitIntervalSec=0
    [Service]
    Type=simple
    Restart=always
    RestartSec=1
    User=root
    ExecStart=%s --host %s --port %s --path=%s
    StandardOutput=append:/var/log/erebor
    StandardError=append:/var/log/erebor

    [Install]
    WantedBy=multi-user.target"""

    try:
        _systemctl("stop erebor")
    except (OSError, CalledProcessError):
        # Best effort: stopping fails when the unit is not installed or not
        # running, which is not an error for a fresh installation.
        pass

    binary = os.path.join(os.path.dirname(os.path.abspath(__file__)), "erebor")
    try:
        with open("/etc/systemd/system/erebor.service", "w") as f:
            f.write(service % (binary, host, port, path))
        for action in ("daemon-reload", "enable erebor", "start erebor"):
            _systemctl(action)
    except (OSError, CalledProcessError) as err:
        # Still non-fatal, but no longer silent (e.g. when not run as root).
        print(
            "(error): could not install the erebor service: %s" % (err,),
            file=sys.stderr,
        )


def persist(identifier, storage_directory=default_storage_directory, *args, **kwargs):
    """Open the table named *identifier* stored under *storage_directory*.

    Extra positional and keyword arguments are forwarded to the
    PersistentDictionary constructor, which passes them on to ``dict``.

    Returns:
        The PersistentDictionary for that table.
    """
    table = PersistentDictionary(identifier, storage_directory, *args, **kwargs)
    return table


def erebor(host, port, storage_directory):
    """Create an Erebor server for the given address and storage directory.

    Returns:
        The newly constructed Erebor instance.
    """
    settings = {
        "host": host,
        "port": port,
        "storage_directory": storage_directory,
    }
    return Erebor(**settings)


class PersistentDictionary(dict):
    """A dict backed by a JSON file named after its identifier.

    The file lives at ``<storage_directory>/<identifier>``. Existing content
    is loaded on construction; changes are only written when save() is
    called.
    """

    def __init__(self, identifier, storage_directory, *args, **kwargs):
        """Create the dictionary and merge in any content already on disk.

        Args:
            identifier: table name, used as the file name.
            storage_directory: directory that holds the table files.
            *args, **kwargs: initial content, forwarded to ``dict``.
        """
        super().__init__(*args, **kwargs)
        self.identifier = identifier
        self.storage_directory = storage_directory
        self.load()

    def load(self):
        """Update the dictionary from its file; a missing file is a no-op."""
        # Open directly instead of checking existence first: the file may be
        # removed by another thread between the check and the open.
        try:
            with open(self.location(), "r", encoding="utf-8") as f:
                self.update(json.load(f))
        except (FileNotFoundError, NotADirectoryError):
            pass

    def save(self):
        """Write the dictionary to its file as indented, key-sorted JSON.

        The storage directory is created if needed.

        Raises:
            Exception: the content cannot be encoded to JSON (the original
                error is attached as the cause).
        """
        os.makedirs(self.storage_directory, exist_ok=True)
        try:
            data = json.dumps(self, indent=4, sort_keys=True)
        except (TypeError, ValueError, RecursionError) as err:
            raise Exception("Data could not be encoded to JSON") from err

        with open(self.location(), "w", encoding="utf-8") as f:
            f.write(data)

    def drop(self):
        """Delete the backing file; an already-missing file is ignored."""
        try:
            os.remove(self.location())
        except FileNotFoundError:
            pass

    def location(self):
        """Return the path of the backing file."""
        return os.path.join(self.storage_directory, "%s" % (self.identifier,))


class Erebor:
    """TCP server exposing file-backed key-value tables.

    A request is one space-separated command (help, version, ls, pwd, dump,
    get, set, del); the reply is a single line of text. Accepted connections
    are queued and served by a fixed pool of worker threads.
    """

    def __init__(
        self,
        host=host,
        port=port,
        storage_directory=default_storage_directory,
    ):
        """Create the storage directory and bind the listening socket.

        Args:
            host: address to bind to.
            port: TCP port to bind to.
            storage_directory: directory holding one file per table.
        """
        self.host = host
        self.port = port
        self.version = version
        self.info = info
        self.storage_directory = storage_directory
        os.makedirs(self.storage_directory, exist_ok=True)
        self.sock = self.erebor_socket(host, port)
        # Set number of workers to number of usable CPUs
        self.workers = len(self._usable_cpus())
        # Set by listen() on shutdown; idle workers exit once it is True.
        self.done = False
        # Accepted connections waiting for a worker.
        self.jobs = Queue()

    @staticmethod
    def _usable_cpus():
        """Return the sorted ids of the CPUs the calling thread may run on."""
        if hasattr(os, "sched_getaffinity"):
            return sorted(os.sched_getaffinity(0))
        # Platforms without the affinity API: fall back to the CPU count.
        return list(range(os.cpu_count() or 1))

    @staticmethod
    def _valid_table(table):
        """Return True if *table* is a plain file name inside the storage dir.

        Table names come from the network and are joined onto the storage
        directory, so anything that could address another location (path
        separators, absolute paths, "..", ".") or is empty is rejected.
        """
        return (
            table not in ("", ".")
            and ".." not in table
            and "\x00" not in table
            and os.path.basename(table) == table
        )

    def log(self, s):
        """Print *s* prefixed with a timestamp and the process id."""
        date = time.strftime("%Y-%m-%d %H:%M:%S")
        print("[{}] [{}] {}".format(date, os.getpid(), s))

    def erebor_socket(self, host, port):
        """Return a listening TCP socket bound to (host, port)."""
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            sock.bind((host, port))
            sock.listen()
        except Exception:
            # Do not leak the descriptor when the address cannot be bound.
            sock.close()
            raise
        return sock

    def listen(self):
        """Start the workers and accept connections until interrupted.

        On KeyboardInterrupt or a socket error the server stops accepting,
        waits for queued connections to finish and exits with status 0.
        """
        self.log("Listening at: %s:%s" % (self.host, str(self.port)))
        for worker in range(self.workers):
            Thread(target=self.handle_connection, args=(self.jobs, worker)).start()
        try:
            while True:
                connection, _ = self.sock.accept()
                self.jobs.put(connection)
        except (KeyboardInterrupt, OSError):
            print("\nShutting down...")
        finally:
            # Always release the workers, even on an unexpected error;
            # they are non-daemon threads and would keep the process alive.
            self.done = True
            self.sock.close()
            self.jobs.join()
        sys.exit(0)

    def _pin_to_cpu(self, worker):
        """Restrict the calling thread to one of the usable CPUs.

        The CPU is picked as (worker number mod # usable CPUs) from the
        list of allowed CPU ids, to ensure an even distribution.
        """
        if not hasattr(os, "sched_setaffinity"):
            return
        cpus = self._usable_cpus()
        try:
            os.sched_setaffinity(0, {cpus[worker % len(cpus)]})
        except OSError as err:
            # Pinning is an optimisation; serve unpinned rather than die.
            self.log("Could not pin worker %d to a CPU: %s" % (worker, err))

    def _serve(self, connection):
        """Answer requests on *connection* until it closes or times out."""
        connection.settimeout(60)
        while True:
            try:
                data = connection.recv(65535)
            except OSError:
                # Includes the 60 s inactivity timeout.
                return
            if not data:
                return
            try:
                req = data.decode("utf-8").strip()
            except UnicodeDecodeError:
                res = "(error): invalid command"
            else:
                res = self.parse(req)
            connection.sendall((res + "\n").encode("utf-8"))

    def handle_connection(self, jobs, worker):
        """Worker loop: take connections from *jobs* and serve them.

        Args:
            jobs: queue of accepted connections.
            worker: index of this worker, used to choose its CPU.

        Returns when the queue is empty and shutdown has been requested.
        """
        self._pin_to_cpu(worker)
        while True:
            try:
                connection = jobs.get(timeout=1)
            except Empty:
                if self.done:
                    return
                continue
            try:
                self._serve(connection)
            except Exception as err:
                self.log("Dropping connection after error: %r" % (err,))
            finally:
                # Runs on every path so the socket is released and
                # jobs.join() in listen() cannot wait forever.
                connection.close()
                jobs.task_done()

    def dump(self, req):
        """Return table ``req[1]`` as JSON, or an error message."""
        table = req[1]
        if not self._valid_table(table):
            return "(error): table name invalid"
        d = persist(table, storage_directory=self.storage_directory)
        if len(d) == 0:
            return "(error): table not found"
        return json.dumps(d)

    def set(self, req, data):
        """Store a value and return "OK", or an error message.

        Args:
            req: the request split on spaces: [command, table, key, ...].
            data: the raw request; everything after the third space is the
                value, so values may themselves contain spaces.
        """
        table, key = req[1], req[2]
        parts = data.split(" ", 3)
        value = parts[3] if len(parts) > 3 else ""
        if value == "":
            return "(error): no value specified"
        if not self._valid_table(table):
            return "(error): table name invalid"
        d = persist(table, storage_directory=self.storage_directory)
        d[key] = value
        d.save()
        return "OK"

    def get(self, req):
        """Return the value of key ``req[2]`` in table ``req[1]``."""
        table = req[1]
        if not self._valid_table(table):
            return "(error): table name invalid"
        key = req[2]
        d = persist(table, storage_directory=self.storage_directory)
        try:
            value = d[key]
        except KeyError:
            return "(error): key not found"
        # Values written by "set" are strings; a table file edited by hand
        # may hold other JSON types, which are returned JSON-encoded.
        return value if isinstance(value, str) else json.dumps(value)

    def delete(self, req):
        """Remove key ``req[2]`` from table ``req[1]``.

        The table file is deleted when its last key is removed.
        """
        table = req[1]
        if not self._valid_table(table):
            return "(error): table name invalid"
        key = req[2]
        d = persist(table, storage_directory=self.storage_directory)
        try:
            d.pop(key)
        except KeyError:
            return "(error): key not found"
        if len(d) > 0:
            d.save()
        else:
            d.drop()
        return "OK"

    def parse(self, data):
        """Execute one request line and return the reply text.

        Args:
            data: the request, already decoded and stripped.

        Returns:
            The command's result, or a string starting with "(error):".
        """
        req = data.split(" ")
        command = req[0]
        argc = len(req)
        if command == "help":
            return self.info
        elif command == "version":
            return self.version
        elif command == "ls":
            return str(os.listdir(self.storage_directory))
        elif command == "pwd":
            return os.path.abspath(self.storage_directory)
        elif command == "dump":
            if argc != 2:
                return "(error): invalid syntax"
            return self.dump(req)
        elif command == "get":
            if argc != 3:
                return "(error): invalid syntax"
            return self.get(req)
        elif command == "set":
            if argc < 3:
                return "(error): invalid syntax"
            return self.set(req, data)
        elif command == "del":
            if argc != 3:
                return "(error): invalid syntax"
            return self.delete(req)
        else:
            return "(error): invalid command"


def boot(args):
    """Start erebor according to the parsed command-line arguments.

    With ``args.daemon`` set, install and start the systemd service (using
    the daemon storage directory unless another --path was given).
    Otherwise create the server, print the startup banner and serve in the
    foreground.

    Args:
        args: namespace with ``host``, ``port``, ``path`` and ``daemon``.
    """
    if args.daemon:
        if args.path == default_storage_directory:
            args.path = daemon_storage_directory
        daemon(host=args.host, port=args.port, path=args.path)
        return

    e = erebor(host=args.host, port=args.port, storage_directory=args.path)
    # Raw strings, one per line: the artwork is full of backslashes that are
    # not valid escape sequences, and explicit quotes keep the trailing
    # spaces from being lost.
    banner = "\n".join(
        (
            "",
            r"     ___",
            r"    /\  \ ",
            r"   /::\  \       Erebor %s",
            r"  /:/\:\  \ ",
            r" /:/  \:\  \ ",
            r"/:/__/ \:\__\    Running in standalone mode",
            r"\:\  \ /:/  /    Port: %-10s ",
            r" \:\  /:/  /     PID:  %-10s ",
            r"  \:\/:/  / ",
            r"   \::/  /             https://pypi.org/project/erebor ",
            r"    \/__/ ",
            "        ",
        )
    )
    print(banner % (e.version, args.port, os.getpid()))
    e.listen()


# Script entry point: parse the command line and hand over to boot().
if __name__ == "__main__":
    p = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=textwrap.dedent(info),
    )
    # Options that take a value: (flag, type, default).
    for flag, kind, fallback in (
        ("--host", str, host),
        ("--port", int, port),
        ("--path", str, default_storage_directory),
    ):
        p.add_argument(flag, type=kind, default=fallback)
    p.add_argument("--daemon", action="store_true")

    args = p.parse_args()
    boot(args)
