#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import print_function

import argparse
import asyncio
import configparser
import copy
import glob
import json
import os.path
import re
import shutil
import sys
import time
from collections import OrderedDict

from pkg_resources import Requirement, resource_filename

import psutil
from papavisor.aioxmlrpc_client import ProtocolError, ServerProxy
from texttable import Texttable

# Parses the start of a process description as returned in the
# 'description' field of getAllProcessInfo():
# group 1 is the pid, group 2 the uptime.
# Raw string on purpose: "\s" and "\d" are not valid escapes in a plain
# string literal.
DESCRIPTION_REGEX = re.compile(r'^pid\s([\d]+),\suptime\s([\d:]+)')


def merge(a, b, path=None):
    """Recursively fold the mapping ``b`` into ``a`` and return ``a``.

    When both sides hold a dict under the same key, the two dicts are
    merged key by key.  In every other case the value from ``b`` replaces
    the one in ``a`` (an equal value already present in ``a`` is left
    untouched).  ``a`` is modified in place.

    ``path`` carries the keys walked so far during recursion; it does not
    influence the result.

    Based on
    http://stackoverflow.com/questions/7204805/dictionaries-of-dictionaries-merge
    """
    trail = [] if path is None else path

    for key, incoming in b.items():
        if key not in a:
            a[key] = incoming
            continue

        current = a[key]
        both_dicts = (isinstance(current, (dict, OrderedDict)) and
                      isinstance(incoming, (dict, OrderedDict)))
        if both_dicts:
            merge(current, incoming, trail + [str(key)])
        elif not current == incoming:
            # differing leaf (or dict vs. non-dict): b wins
            a[key] = incoming
    return a


class SupervisorCtl(object):
    """Control the programs of one project through its XML-RPC interface.

    ``config`` is the project configuration mapping.  This class reads the
    keys ``name``, ``url``, ``user`` and ``password``, the ``programs``
    mapping (program name -> definition with ``type``, ``groups``,
    ``priority`` and the optional ``startsecs`` / ``tcp_check_port`` /
    ``tcp_check_host`` / ``tcp_check_timeout`` settings) and the optional
    ``groups`` mapping (group name -> options such as ``start_extra``).

    The RPC-backed methods are native coroutines (``async def``); the
    generator-based ``asyncio.coroutine`` decorator no longer exists in
    Python 3.11+.  They remain usable from ``yield from`` inside
    ``@asyncio.coroutine`` functions on the Python versions that still
    provide that decorator (3.5+).
    """

    def __init__(self, config):
        """Create the RPC proxy for the project described by ``config``.

        Raises:
            OSError: the proxy could not be created; the message is
                extended with the offending URL.
        """
        self._name = config['name']
        try:
            self._rpc = ServerProxy(
                config['url'],
                username=config['user'],
                password=config['password']
            )
        except OSError as e:
            raise OSError("%s, on URL %r" % (e, config['url'])) from e

        self._rpcns = getattr(self._rpc, 'supervisor')

        self._config = config
        # group name -> OrderedDict(program name -> priority), by priority
        self._groups = {}
        # program name -> process info of the last status query
        self._programs = {}
        # program type -> list of program names of the last status query
        self._programs_by_type = {}

    async def _get_status(self):
        """Query all process infos and refresh the cached program state.

        Processes that are not listed in the configured ``programs`` are
        reported on stderr and ignored.  The group table is built only
        once, on the first call that finds it empty.
        """
        cfg_programs = self._config['programs']

        status = await self._rpcns.getAllProcessInfo()
        programs = {}
        for p in status:
            pn = p['name']
            if pn not in cfg_programs:
                print(
                    "%s\tERROR: Unknown service name %r" % (
                        self._name, pn
                    ),
                    file=sys.stderr
                )
                continue

            # Enrich the process info with the start/check settings.
            pdefs = cfg_programs[pn]
            p['startsecs'] = pdefs.get('startsecs', 0)
            p['tcp_check_port'] = pdefs.get('tcp_check_port', 0)
            p['tcp_check_host'] = pdefs.get('tcp_check_host', "127.0.0.1")
            p['tcp_check_timeout'] = pdefs.get('tcp_check_timeout', 300)
            programs[pn] = p

        self._programs_by_type = self._group_programs_by_type(programs)

        if not self._groups:
            # ``programs`` only holds configured names at this point.
            groups = {}
            for pn in programs:
                sdefs = cfg_programs[pn]
                for sgroup in sdefs['groups']:
                    groups.setdefault(sgroup, {})[pn] = sdefs['priority']

            for k, g in groups.items():
                self._groups[k] = OrderedDict(
                    sorted(g.items(), key=lambda item: item[1])
                )

        self._programs = programs

    def close(self):
        """Close the RPC proxy, if there is one."""
        if self._rpc is not None:
            self._rpc.close()

    async def stop(self, group_or_program):
        """Stop the given group or single program.

        Raises:
            ValueError: the name is neither a known group nor a program.
        """
        await self._get_status()

        if group_or_program in self._groups:
            await self._stop_group(group_or_program)
        elif group_or_program in self._programs:
            await self._stop_program(group_or_program)
        else:
            raise ValueError("Unknown group or program %r" % group_or_program)

    async def _stop_program(self, p, force=False):
        """Stop program ``p`` if it is running, or always with ``force``."""
        print("Stop program %r from project %r:" % (p, self._name))
        if force or self._p_is_running(p):
            print("%s\tstop\t%s" % (self._name, p))
            await self._rpcns.stopProcess(p)

    async def _stop_group(self, group):
        """Stop the programs of ``group`` in ascending priority order."""
        print("Stop group %r from project %r:" % (group, self._name))
        for p in self._groups[group]:
            await self._stop_program(p)

    async def start(self, group_or_program):
        """Start the given group or single program.

        Raises:
            ValueError: the name is neither a known group nor a program.
        """
        await self._get_status()

        if group_or_program in self._groups:
            await self._start_group(group_or_program)
        elif group_or_program in self._programs:
            await self._start_program(group_or_program)
        else:
            raise ValueError("Unknown group or program %r" % group_or_program)

    async def _start_group(self, group):
        """Start the programs of ``group`` in descending priority order."""
        print("Start group %r from project %r:" % (group, self._name))
        for p in reversed(self._groups[group]):
            await self._start_program(p)

    async def _start_program(self, p, force=False, check_or_block=False):
        """Start program ``p`` if it is stopped, or always with ``force``.

        With ``check_or_block`` the call additionally waits until the
        program's TCP check port accepts connections (when a port is
        configured) and then sleeps ``startsecs`` seconds (when set).
        A failed TCP check is printed and does not abort the start.
        """
        print("Start program %r from project %r:" % (p, self._name))
        if not (force or self._p_is_stopped(p)):
            return

        print("%s\tstart\t%s" % (self._name, p))
        await self._rpcns.startProcess(p)
        if not check_or_block:
            return

        pdefs = self._programs[p]
        st = pdefs.get('startsecs', 0)
        cp = pdefs.get('tcp_check_port', 0)
        if cp > 0:
            print("%s\ttcp_check\t%s\t%d" % (self._name, p, cp))
            try:
                await self._tcp_check(p)
            except OSError as e:
                print("ERROR: %s" % e)
        if st > 0:
            print("%s\tsleep\t%s\t%d" % (self._name, p, st))
            await asyncio.sleep(st)

    async def restart(self, group):
        """Restart the programs of ``group`` that are currently running.

        Running programs are stopped in priority order and started again
        in reverse order.  If the group option ``start_extra`` is set, a
        temporary additional instance is started beforehand for every
        program type that has fewer than two running instances in the
        group, and stopped again once the restart is done.

        Raises:
            ValueError: ``group`` is not a known group.
        """
        await self._get_status()

        if group not in self._groups:
            raise ValueError("Unknown group %r" % group)

        group_opts = self._config.get('groups', {}).get(group, {})
        start_extra = group_opts.get('start_extra', False)

        # Only programs that are running right now take part.
        to_restart = [
            p for p in self._groups[group] if self._p_is_running(p)
        ]

        to_stop = []
        if start_extra:
            print(
                ("Restart group %r from project %r"
                 " with options (start_extra: %r)") % (
                    group, self._name,
                    start_extra
                ))
            to_stop = await self._start_extra_programs(to_restart)
        else:
            print("Restart group %r from project %r:" % (group, self._name))

        # Stop selected programs
        for p in to_restart:
            await self._stop_program(p)

        # and start them in reversed order again.
        for p in reversed(to_restart):
            await self._start_program(
                p,
                force=True,
                check_or_block=True
            )

        # now stop "temporary" instances
        for p in to_stop:
            await self._stop_program(p, force=True)

    async def _start_extra_programs(self, running):
        """Start one spare program per thinly staffed type; return them.

        ``running`` is the list of program names about to be restarted.
        For each of their types with fewer than two entries in ``running``
        and more than one known program overall, the first known program
        of that type that is not in ``running`` is force-started.
        """
        started = []
        for p_type, active in self._group_programs_by_type(running).items():
            same_type = self._programs_by_type[p_type]
            if len(active) >= 2 or len(same_type) <= 1:
                continue

            # Deterministic choice: first candidate in status order.
            spare = next((c for c in same_type if c not in active), None)
            if spare is None:
                continue

            started.append(spare)
            await self._start_program(
                spare,
                force=True,
                check_or_block=True)
        return started

    async def status(self, group, texttable=None):
        """Report state, pid, uptime and resident memory of a group.

        Each program of ``group`` yields one row.  Rows are added to
        ``texttable`` when one is given, otherwise they are printed as
        tab-separated lines.

        Raises:
            ValueError: ``group`` is not a known group.
        """
        await self._get_status()

        if group not in self._groups:
            raise ValueError("Unknown group %r" % group)

        for p in self._groups[group]:
            pdata = self._programs[p]

            pid = ''
            mem_rss = ''
            uptime = ''

            if self._p_is_running(p):
                desc_match = DESCRIPTION_REGEX.match(pdata['description'])
                if desc_match:
                    pid, uptime = desc_match.groups()

                    try:
                        rss = psutil.Process(int(pid)).memory_info().rss
                        mem_rss = "%d MiB" % (rss // (1024 * 1024))
                    except (psutil.NoSuchProcess, psutil.AccessDenied):
                        # Process is gone or may not be inspected by this
                        # user: leave the memory column empty.
                        pass

            if texttable is None:
                print("\t".join((
                    self._name,
                    p.ljust(15),
                    pdata['statename'],
                    pid,
                    uptime,
                    mem_rss,
                )))
            else:
                texttable.add_row([
                    self._name,
                    p,
                    pdata['statename'],
                    pid,
                    uptime,
                    mem_rss,
                ])

    def _p_is_running(self, p):
        """Return True if ``p`` was last seen as RUNNING or STARTING."""
        return self._programs[p]['statename'] in ('RUNNING', 'STARTING')

    def _p_is_stopped(self, p):
        """Return True if ``p`` was last seen as STOPPED or STOPPING."""
        return self._programs[p]['statename'] in ('STOPPED', 'STOPPING')

    def _group_programs_by_type(self, programs):
        """Map each configured ``type`` to the given program names.

        ``programs`` is any iterable of program names that exist in the
        configured ``programs`` mapping; their order is preserved.
        """
        cfg_programs = self._config['programs']
        programs_by_type = {}
        for p in programs:
            programs_by_type.setdefault(cfg_programs[p]['type'], []).append(p)

        return programs_by_type

    async def _tcp_check(self, program):
        """Wait until the program's TCP check address accepts a connection.

        A connection attempt is made once per second until it succeeds.

        Raises:
            OSError: no attempt succeeded before ``tcp_check_timeout``
                seconds had passed.
        """
        pdefs = self._programs[program]
        host = pdefs.get('tcp_check_host')
        port = pdefs.get('tcp_check_port')
        timeout = pdefs.get('tcp_check_timeout')

        # Monotonic clock: the deadline must not move with wall-clock jumps.
        deadline = time.monotonic() + timeout
        while True:
            try:
                _reader, writer = await asyncio.open_connection(host, port)
            except OSError as exc:
                if deadline <= time.monotonic():
                    raise OSError(
                        "Failed to connect to (%r, %r), timeout occurred." % (
                            host, port
                        )
                    ) from exc
                await asyncio.sleep(1)
            else:
                # The connection only served as a probe; do not leak it.
                writer.close()
                return


def get_configobj_from_path(config_path):
    """Load the project configuration found in directory ``config_path``.

    Every ``*.json`` file in the directory is parsed and merged into one
    mapping.  Files are processed in sorted file-name order (glob itself
    gives no ordering guarantee), so a later file predictably overrides
    an earlier one.  The ``__defaults__`` entry, if present, is removed
    from the result and returned separately.

    If the environment variable ``CONFIG_FILES`` is set, it is read as a
    ``;``-separated list of ``name:path`` pairs.  For every name that is
    not already configured, ``path`` is read as an INI file and the
    ``serverurl`` / ``username`` / ``password`` options of its
    ``[supervisorctl]`` section become that project's ``url`` / ``user``
    / ``password``.  Malformed input and incomplete files are reported on
    stderr and skipped.

    Returns:
        A ``(config, service_defaults)`` tuple; both are empty when
        ``config_path`` does not exist.

    Raises:
        ValueError: a JSON file could not be parsed (its name is printed
            before the error is re-raised).
    """
    if not os.path.exists(config_path):
        return OrderedDict(), OrderedDict()

    config = OrderedDict()

    for cfg_file in sorted(glob.glob(os.path.join(config_path, '*.json'))):
        with open(cfg_file, 'r') as fp:
            try:
                tmp_cfg = json.load(fp, object_pairs_hook=OrderedDict)
            except ValueError:
                print(cfg_file)
                raise

        merge(config, tmp_cfg)

    # A directory without a defaults file must not be fatal.
    service_defaults = config.pop('__defaults__', OrderedDict())

    # env CONFIG_FILES for apapavisor
    if 'CONFIG_FILES' not in os.environ:
        return config, service_defaults

    raw_files = os.environ['CONFIG_FILES']
    try:
        # Split on the first ':' only, so the path part may contain colons.
        discovered = OrderedDict(
            entry.split(':', 1) for entry in raw_files.rstrip(';').split(';')
        )
    except ValueError:
        # At least one entry had no "name:path" shape.
        if raw_files == '':
            print(
                "apapavisor hasn't found any projects.",
                file=sys.stderr
            )
        else:
            print(
                "ERROR: Wrong input form apapavisor: %r" % (
                    raw_files,
                ),
                file=sys.stderr
            )
        return config, service_defaults

    for name, cfgfile in discovered.items():
        if name in config:
            # Do not overwrite a manual configured entry.
            continue

        parser = configparser.ConfigParser()
        parser.read(cfgfile)
        try:
            section = parser['supervisorctl']
            cfg_opts = {
                'url': section['serverurl'],
                'user': section['username'],
                'password': section['password']
            }
        except KeyError as e:
            # Unreadable file, missing section or missing option.
            print(
                "ERROR: Missing %s in config file %r" % (e, cfgfile),
                file=sys.stderr
            )
            continue

        config[name] = cfg_opts

    return config, service_defaults


async def _run_task_async(config, action='restart', group='python', table=None):
    """Run ``action`` on ``group`` for the project described by ``config``.

    ``action`` is matched case-insensitively against ``restart``,
    ``start``, ``stop`` and ``status``; any other value is reported on
    stderr.  For ``status`` the rows are added to ``table`` when one is
    given (its column alignment and header are set first); without a
    table the status lines are printed directly.

    A failure to create the controller (``OSError``) and a
    ``ProtocolError`` during the action are printed instead of raised.
    The controller is always closed afterwards.

    Written as a native coroutine because the generator-based
    ``asyncio.coroutine`` decorator no longer exists in Python 3.11+.
    """
    try:
        sctl = SupervisorCtl(config)
    except OSError as e:
        print('ERROR: %s' % e)
        return

    action = action.lower()
    try:
        if action == 'restart':
            await sctl.restart(group)

        elif action == 'start':
            await sctl.start(group)

        elif action == 'stop':
            await sctl.stop(group)

        elif action == 'status':
            # ``table`` defaults to None; only configure a real table.
            if table is not None:
                table.set_cols_align(["l", "l", "l", "l", "l", "r"])
                table.header([
                    'project',
                    'program',
                    'status',
                    'pid',
                    'uptime',
                    'mem rss'
                ])
            await sctl.status(group, table)

        else:
            print(
                "%s\tERROR: Unknown action %r" % (config['name'], action),
                file=sys.stderr
            )

    except ProtocolError:
        print('ERROR: can\'t connect to project %r.' % (
            config['name']
        ))
    finally:
        sctl.close()


def main(config_path, project, action, group_or_program):
    """Run ``action`` on ``group_or_program`` for the selected projects.

    Args:
        config_path: directory holding the JSON configuration files.
        project: ``'all'`` or a (case-insensitive) prefix of the project
            names to select.
        action: action name handed to ``_run_task_async``.
        group_or_program: group or program name the action applies to.

    Every selected, enabled project gets its own configuration (service
    defaults merged with the project entry, per-type settings copied into
    the programs of that type) and its own task.  All tasks run
    concurrently on a fresh event loop; tasks that ended with an
    exception are reported on stderr.  A status table is printed when it
    received rows.  If no project is left to run, a message is printed
    on stderr.
    """
    # Read config files
    config, service_defaults = get_configobj_from_path(config_path)

    # Select projects
    if project == 'all':
        to_run = config  # default all projects
    else:
        # use startswith so users can provide "half" project names.
        prefix = project.lower()
        to_run = {
            k: v for k, v in config.items()
            if k.lower().startswith(prefix)
        }

    texttable = Texttable()
    coros = []
    for pname, pconfig in to_run.items():

        if not pconfig.get('enabled', True):
            continue

        project_config = copy.deepcopy(service_defaults)
        merge(project_config, pconfig)

        # Take the per-type settings out of the config once, up front.
        # Deleting the key from inside the loop over the types fails as
        # soon as there is a second type.
        type_settings = project_config.pop('types', {})
        programs = project_config.get('programs', {})
        for tk, tv in type_settings.items():
            tk = tk.lower()
            for pv in programs.values():
                if pv['type'].lower() == tk:
                    pv.update(tv)

        project_config['name'] = pname

        coros.append(
            _run_task_async(
                project_config,
                action,
                group_or_program,
                texttable
            )
        )

    if not coros:
        print("Project(s) %r not found." % project, file=sys.stderr)
        return

    # ``asyncio.async`` is a syntax error since ``async`` became a keyword
    # (Python 3.7); create the tasks on an explicitly created loop instead.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        tasks = [loop.create_task(coro) for coro in coros]
        done, _pending = loop.run_until_complete(asyncio.wait(tasks))
    finally:
        loop.close()

    # Surface failures instead of leaving the exceptions unretrieved.
    for task in done:
        if task.cancelled():
            continue
        error = task.exception()
        if error is not None:
            print("ERROR: %s" % error, file=sys.stderr)

    if texttable._rows:
        print(texttable.draw() + "\n")


if __name__ == '__main__':
    # Command line: three optional positional arguments, each with a
    # default that is used when the argument is left out.
    parser = argparse.ArgumentParser()
    for arg_name, arg_help, arg_default in (
            ('project',
             "Project name, all or the startstring of a project",
             "all"),
            ('action',
             "Action. status/start/stop/restart",
             "status"),
            ('group_or_program',
             "Group or program to run the action on",
             "all"),
    ):
        parser.add_argument(
            arg_name,
            help=arg_help,
            nargs='?',
            default=arg_default
        )
    options = parser.parse_args()

    config_path = '/etc/papavisor'
    if not os.path.exists(config_path):
        # No system-wide directory: fall back to the per-user directory
        # and seed it with the packaged files that are still missing.
        user_config_root = os.path.expanduser('~/.config')
        config_path = os.path.join(user_config_root, 'papavisor')
        if not os.path.exists(config_path):
            os.makedirs(config_path, 0o700)

        for fname in ('00_defaults.json', '01_template.json',
                      'apapavisor.sh'):
            target = os.path.join(config_path, fname)
            if os.path.exists(target):
                continue

            packaged = resource_filename(
                Requirement.parse('papavisor'),
                'papavisor/config/%s' % fname
            )
            shutil.copyfile(packaged, target)

    main(
        config_path,
        options.project,
        options.action,
        options.group_or_program
    )
