#!/usr/bin/env python3


import sys
import os
import argparse
import asyncio
import logging
from logging.handlers import RotatingFileHandler
from pathlib import Path
from collections import defaultdict
import getpass

import uvloop

from suzieq.poller.nodes import init_hosts, init_files
from suzieq.poller.services import init_services

from suzieq.poller.writer import init_output_workers, run_output_worker
from suzieq.utils import load_sq_config


def validate_parquet_args(cfg, output_args, logger):
    """Validate user arguments for parquet output.

    Resolves the parquet output directory, creates it if needed and records
    it in ``output_args`` under the ``"output_dir"`` key.

    Args:
        cfg: Configuration mapping; the ``"data-directory"`` entry, when set
            to a non-empty value, names the output directory.
        output_args: Dict updated in place with the resolved ``output_dir``.
        logger: Logger used for the warning/error/info messages.

    Raises:
        SystemExit: With code 1 when the output path exists but is not a
            directory.
    """
    output_dir = cfg.get("data-directory", None)
    if not output_dir:
        output_dir = "/tmp/suzieq/parquet-out/"
        logger.warning(
            "No output directory for parquet specified, using "
            "/tmp/suzieq/parquet-out"
        )

    try:
        # exist_ok avoids the check-then-create race when the directory
        # appears between the test and the creation.
        os.makedirs(output_dir, exist_ok=True)
    except FileExistsError:
        # The path exists but is not a directory; reported just below.
        pass

    if not os.path.isdir(output_dir):
        message = "Output directory {} is not a directory".format(output_dir)
        logger.error(message)
        print(message)
        sys.exit(1)

    logger.info("Parquet outputs will be under {}".format(output_dir))
    output_args.update({"output_dir": output_dir})


def init_logger(cfg=None):
    """Initialize the poller logger and attach a rotating file handler.

    Args:
        cfg: Optional configuration mapping. ``"logging-level"`` (default
            ``"WARNING"``) sets the level and ``"log-file"`` (default
            ``"/tmp/sq-poller.log"``) names the log file. When omitted, the
            module-level ``cfg`` set by the script entry point is used, and
            the defaults apply if that does not exist either.

    Returns:
        The ``suzieq.poller`` logger.
    """
    if cfg is None:
        # Existing callers pass nothing and rely on the global created under
        # ``__main__``; look it up explicitly so that importing this module
        # and calling the function does not fail with a NameError.
        cfg = globals().get("cfg") or {}

    level = cfg.get("logging-level", "WARNING").upper()

    # this needs to be suzieq.poller, so that it is the root of all the
    # other pollers
    logger = logging.getLogger('suzieq.poller')
    logger.setLevel(level)

    fh = RotatingFileHandler(cfg.get("log-file", "/tmp/sq-poller.log"),
                             maxBytes=10000000, backupCount=2)
    fh.setFormatter(logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    ))

    # Set the root logger level as well, so that the asyncssh log level is
    # set: asyncssh sets its level to the root level.
    root = logging.getLogger()
    root.setLevel(level)
    root.addHandler(fh)

    logger.warning(f"log level {logging.getLevelName(logger.level)}")

    return logger


def send_terminate_to_threads(nodes, svcs: list, writers: list):
    """Set the terminate flag (``sigend``) on all nodes, services and writers.

    Args:
        nodes: Node objects, either as a list or as a mapping whose values
            are the node objects.
        svcs: Service objects.
        writers: Output writer objects.
    """
    # Iterating a mapping directly yields its keys, which cannot take the
    # flag; the node objects are the values.
    node_objs = nodes.values() if isinstance(nodes, dict) else nodes

    # NOTE(review): nodes previously received ``signend`` while services and
    # writers received ``sigend``; assumed to be a typo - confirm that the
    # node class reads ``sigend``.
    for elem in (*node_objs, *svcs, *writers):
        elem.sigend = True


async def start_poller(userargs, cfg):
    """Validate the inputs, then run the nodes, services and output writer.

    Args:
        userargs: Parsed command line arguments (argparse namespace).
        cfg: Configuration mapping. ``"service-directory"`` is required;
            ``"schema-directory"``, ``"data-directory"`` and ``"period"``
            are optional.

    Returns:
        None, once no service task is left running.

    Raises:
        SystemExit: On invalid arguments or configuration, or when no nodes
            or services could be initialized.
        KeyboardInterrupt: Re-raised after asking the workers to terminate.
    """
    logger = init_logger()

    if userargs.devices_file and userargs.namespace:
        logger.error("Cannot specify both -D and -n options")
        sys.exit(1)

    service_dir = cfg["service-directory"]
    # The path has to be a directory, not merely exist, as the message says.
    if not os.path.isdir(service_dir):
        message = "Service directory {} is not a directory".format(service_dir)
        logger.error(message)
        print(message)
        sys.exit(1)

    if userargs.jump_host and not userargs.jump_host.startswith('//'):
        logger.error("Jump host format is //<username>@<jumphost>:<port>")
        print("Jump host format is //<username>@<jumphost>:<port>")
        sys.exit(1)

    # Without an explicit schema directory, the schema lives under the
    # service directory. The service directory comes from the config: there
    # is no service-dir command line option to read it from.
    schema_dir = cfg.get("schema-directory", None)
    if not schema_dir:
        schema_dir = "{}/{}".format(service_dir, "schema")

    if userargs.ssh_config_file:
        ssh_config_file = os.path.expanduser(userargs.ssh_config_file)
        # abspath keeps dirname non-empty for a bare relative file name.
        ssh_dir = os.path.dirname(os.path.abspath(ssh_config_file))
        # Reject anything that is not a directory limited to owner rwx.
        if os.stat(ssh_dir).st_mode | 0o40700 != 0o40700:
            logger.error('ssh directory has wrong permissions, must be 0700')
            print('ERROR: ssh directory has wrong permissions, must be 0700')
            sys.exit(1)

    output_args = {}

    if "parquet" in userargs.outputs:
        validate_parquet_args(cfg, output_args, logger)

    if userargs.run_once:
        userargs.outputs = ["gather"]
        output_args['output_dir'] = userargs.output_dir

    outputs = init_output_workers(userargs.outputs, output_args)

    queue = asyncio.Queue()

    logger.info("Initializing hosts and services")
    tasks = []

    ignore_known_hosts = userargs.ignore_known_hosts or None

    # Always list the service files: they are needed for the error message
    # below even when the user supplied an explicit service list.
    svc_files = list(Path(service_dir).glob('*.yml'))
    if userargs.service_only:
        svclist = userargs.service_only.split()
    else:
        svclist = [os.path.basename(x).split('.')[0] for x in svc_files]

    if userargs.exclude_services:
        # Compare against whole service names; testing against the raw
        # string would also drop any service whose name is a substring.
        excluded = set(userargs.exclude_services.split())
        svclist = [svc for svc in svclist if svc not in excluded]

    if not svclist:
        print(
            f"No correct services specified. Should have been one of {[svc.name for svc in svc_files]}")
        sys.exit(1)

    if userargs.input_dir:
        tasks.append(init_files(userargs.input_dir))
    else:
        tasks.append(init_hosts(inventory=userargs.devices_file,
                                ans_inventory=userargs.ansible_file,
                                namespace=userargs.namespace,
                                passphrase=userargs.passphrase,
                                jump_host=userargs.jump_host,
                                jump_host_key_file=userargs.jump_host_key_file,
                                password=userargs.ask_pass,
                                ignore_known_hosts=ignore_known_hosts))

    period = cfg.get('period', 15)
    tasks.append(init_services(service_dir, schema_dir, queue,
                               svclist, period,
                               userargs.run_once or "forever"))

    nodes, svcs = await asyncio.gather(*tasks)

    if not nodes or not svcs:
        # Logging should've been done by init_nodes/services for details
        print('Termminating because no nodes or services found')
        sys.exit(0)

    node_callq = defaultdict(lambda: defaultdict(dict))

    node_callq.update({x: {'hostname': nodes[x].hostname,
                           'postq': nodes[x].post_commands}
                       for x in nodes})

    for svc in svcs:
        svc.set_nodes(node_callq)

    logger.info("Suzieq Started")

    running = []
    try:
        # The writer and node tasks are tracked separately from the service
        # tasks so that we can terminate properly when all the services have
        # finished, as in the case of using file input instead of SSH.
        #
        # Everything is wrapped in a future up front: asyncio.wait() no
        # longer accepts bare coroutines on recent Python versions, and
        # holding the service futures avoids relying on the private
        # Task._coro attribute to recognize them.
        svc_tasks = {asyncio.ensure_future(svc.run()) for svc in svcs}
        running = [asyncio.ensure_future(nodes[node].run()) for node in nodes]
        running.append(
            asyncio.ensure_future(run_output_worker(queue, outputs)))
        running.extend(svc_tasks)

        while running:
            _, pending = await asyncio.wait(
                running, return_when=asyncio.FIRST_COMPLETED)
            running = list(pending)
            # Stop as soon as no service task is left pending.
            if svc_tasks.isdisjoint(pending):
                break
        return

    except KeyboardInterrupt:
        logger.warning("sq-poller: Received terminate signal. Terminating")
        # nodes is a mapping; hand over the node objects themselves.
        send_terminate_to_threads(list(nodes.values()), svcs, outputs)
        await asyncio.gather(*running)
        raise

if __name__ == "__main__":
    # Script entry point: parse the command line, prompt for any secrets,
    # load the configuration and run the poller until it finishes or the
    # user interrupts it.

    homedir = str(Path.home())
    supported_outputs = ["parquet"]

    parser = argparse.ArgumentParser()

    # Exactly one source of devices/data has to be given.
    source_grp = parser.add_mutually_exclusive_group(required=True)
    source_grp.add_argument(
        "-D", "--devices-file", type=str,
        help="File with URL of devices to gather data from",
    )
    source_grp.add_argument(
        "-a", "--ansible-file", type=str,
        help="Ansible inventory file of devices to gather data from",
    )
    source_grp.add_argument(
        "-i", "--input-dir", type=str,
        help="Directory where run-once=gather data is"
    )

    # A namespace is mandatory when an ansible inventory flag is present on
    # the command line.
    ansible_requested = any(
        flag in sys.argv for flag in ('--ansible-file', "-a"))
    parser.add_argument(
        "-n", "--namespace", type=str, required=ansible_requested,
        help="Namespace to associate for the gathered data"
    )
    parser.add_argument(
        "-o", "--outputs", nargs="+", default=["parquet"],
        choices=supported_outputs, type=str,
        help="Output formats to write to: parquet. Use "
        "this option multiple times for more than one output",
    )

    # Service selection.
    parser.add_argument(
        "-s", "--service-only", type=str,
        help="Only run this space separated list of services",
    )
    parser.add_argument(
        "-x", "--exclude-services", type=str,
        help="Exclude running this space separated list of services",
    )

    parser.add_argument(
        "-c", "--config", type=str, help="alternate config file"
    )

    # Options hidden from the help output.
    parser.add_argument(
        "--run-once", type=str, choices=["gather", "process"],
        help=argparse.SUPPRESS,
    )
    parser.add_argument(
        "--output-dir", type=str,
        default=f'{os.path.abspath(os.curdir)}/sqpoller-output',
        help=argparse.SUPPRESS,
    )

    # Credential prompts and SSH related options.
    parser.add_argument(
        "--ask-pass", default=False, action='store_true',
        help="prompt to enter password for login to devices",
    )
    parser.add_argument(
        "--passphrase", default=False, action='store_true',
        help="prompt to enter private key passphrase",
    )
    parser.add_argument(
        "-j", "--jump-host", default="", type=str,
        help="Jump Host via which to access the devices, IP addr or DNS hostname"
    )
    parser.add_argument(
        "-K", "--jump-host-key-file", default="", type=str,
        help="Key file to be used for jump host"
    )
    parser.add_argument(
        "-k", "--ignore-known-hosts", default=False, action='store_true',
        help="Ignore Known Hosts File",
    )
    parser.add_argument(
        "--ssh-config-file", type=str, default=None,
        help="Path to ssh config file to use. If not set, config file is not used"
    )

    userargs = parser.parse_args()

    # Each flag is replaced by the secret typed at the prompt, or by None
    # when the flag was not given.
    userargs.passphrase = (
        getpass.getpass('Passphrase to decode private key file: ')
        if userargs.passphrase else None
    )
    userargs.ask_pass = (
        getpass.getpass('Password to login to device: ')
        if userargs.ask_pass else None
    )

    uvloop.install()

    # cfg stays a module-level name: init_logger() reads it as a global.
    cfg = load_sq_config(config_file=userargs.config)
    if not cfg:
        print("Could not load config file, aborting")
        sys.exit(1)

    try:
        asyncio.run(start_poller(userargs, cfg))
    except KeyboardInterrupt:
        # Ctrl-C is the normal way to stop the poller; exit quietly.
        pass
