#! /usr/bin/python3
import codecs, os, ssl, sys, sqlite3, urllib.request
from datetime import datetime, timedelta
from getopt import gnu_getopt, GetoptError
from html.parser import HTMLParser

# This probe relies on the formatting of the dCache Dataset Restore Monitor
# page and may break in future versions of dCache.  Please let me know if there
# is a more machine-friendly source for the information on that page.

# Check whether a PNFS id has been in restore for more than 12 h.


# Parser for /poolInfo/restoreHandler/*
# =====================================

class ParseError(Exception):
    """Raised when the restore page contains markup the parser does not expect."""

# When true, check() additionally lists every tracked restore.
verbose = False

# States of the table parser, numbered consecutively from zero.
(INIT, TABLE, THEAD, THEAD_TR, THEAD_TD,
 TBODY, TBODY_TR, TBODY_TD) = range(8)

# (current state, opening tag) -> state entered by that tag.
start_transitions = {
    (INIT, 'table'):  TABLE,
    (TABLE, 'thead'): THEAD,
    (THEAD, 'tr'):    THEAD_TR,
    (THEAD_TR, 'th'): THEAD_TD,
    (TABLE, 'tbody'): TBODY,
    (TBODY, 'tr'):    TBODY_TR,
    (TBODY_TR, 'td'): TBODY_TD,
}

# state -> (tag expected to close it, state to return to).  This is the
# inverse of start_transitions.
end_transitions = {
    inner: (tag, outer)
    for (outer, tag), inner in start_transitions.items()
}

# state -> name of the tag that opened it, used in error messages.
tag_of_state = {
    state: tag
    for (_, tag), state in start_transitions.items()
}

class RestoreParser(HTMLParser):
    """Collect the rows of the restore table as a list of dicts.

    The parser is a small state machine driven by start_transitions and
    end_transitions.  Tags seen outside a table are ignored; inside a table
    only the thead/tbody/tr/th/td nesting described by start_transitions is
    accepted and anything else raises ParseError.

    After feeding the page, ``result`` is None if no table was seen, and
    otherwise a list with one dict per body row, mapping the text of each
    header cell to the text of the corresponding body cell.
    """

    def __init__(self):
        HTMLParser.__init__(self)
        self.result = None
        self._state = INIT
        self._hdr = None    # Header cell texts, in column order.
        self._row = None    # The body row currently being filled in.
        self._data = None   # Text of the current cell, or None outside cells.

    def handle_starttag(self, tag, attrs):
        """Advance the state machine on an opening tag.

        Raises ParseError for a tag which is not allowed in the current
        state, unless the parser is outside any table.
        """
        key = (self._state, tag)
        if key in start_transitions:
            self._state = start_transitions[key]
            if self._state == TABLE:
                # Only start a fresh result when a table opens.  Resetting
                # it on every transition would discard all rows but the last.
                self.result = []
            elif self._state == THEAD_TR:
                self._hdr = []
            elif self._state == TBODY_TR:
                self._row = {}
            elif self._state in (THEAD_TD, TBODY_TD):
                self._data = ""
        elif self._state != INIT:
            raise ParseError('Unexpected tag %s in %s.'
                             % (tag, tag_of_state[self._state]))

    def handle_endtag(self, tag):
        """Store the finished cell or row and return to the enclosing state.

        Raises ParseError if the closing tag does not match the element
        which is currently open, or if a body row has more cells than there
        are column headers.
        """
        if self._state == INIT:
            return
        expected_tag, outer_state = end_transitions[self._state]
        # Validate before touching any state, and with a real exception
        # rather than an assert, which is stripped when running under -O.
        if tag != expected_tag:
            raise ParseError('Unexpected end tag %s in %s.'
                             % (tag, tag_of_state[self._state]))
        if self._state == TBODY_TR:
            self.result.append(self._row)
            self._row = None
        elif self._state == THEAD_TD:
            self._hdr.append(self._data)
            self._data = None
        elif self._state == TBODY_TD:
            # The number of cells stored so far is the index of this column.
            column = len(self._row)
            if self._hdr is None or column >= len(self._hdr):
                raise ParseError('Table row has more cells than headers.')
            self._row[self._hdr[column]] = self._data
            self._data = None
        self._state = outer_state

    def handle_data(self, data):
        """Accumulate text while inside a header or body cell."""
        if self._data is not None:
            self._data += data

# State
# =====

# Directory for the state database and for saved pages which failed to parse.
top_statedir = os.environ.get('NAGIOS_PLUGIN_STATE_DIR',
                              '/var/spool/nagios/plugins')
db_path = os.path.join(top_statedir, 'check_dcache_restore.db')

# One row per PNFS id the probe has seen in restore.
db_schema = """
CREATE TABLE active_restore (
    pnfs_id text PRIMARY KEY,
    first_seen timestamp NOT NULL DEFAULT current_timestamp,
    last_seen timestamp NOT NULL DEFAULT current_timestamp
);
"""

# Statements used by RestoreState.
db_insert = "INSERT INTO active_restore (pnfs_id) VALUES (?)"
db_update = ("UPDATE active_restore SET last_seen = current_timestamp "
             "WHERE pnfs_id = ?")
db_select = "SELECT pnfs_id, first_seen, last_seen FROM active_restore"
# Forget entries which have not been seen during the last 24 hours.
db_prune = ("DELETE FROM active_restore "
            "WHERE last_seen < datetime('now', '-24 hours')")

class RestoreState(object):
    """Record of the PNFS ids seen in restore, backed by an SQLite database.

    ``data`` maps each known PNFS id to a ``(first_seen, last_seen)`` pair of
    naive UTC datetimes.  ``now`` is the time of the current probe run; ids
    registered with add() during this run get ``last_seen == now``.  Changes
    are only written to the database by save().
    """

    def __init__(self):
        """Open the database at db_path and load the entries it contains.

        The schema is created if the table is missing.  Otherwise entries
        matched by db_prune are deleted before the rest are loaded.
        """
        # Imported locally since the module only imports datetime and
        # timedelta from the datetime module.
        from datetime import timezone

        self._dbconn = sqlite3.connect(db_path)
        self._to_update = []
        self._to_insert = []
        self.data = {}
        # Naive UTC, to be comparable with the naive timestamps parsed from
        # the database.  Equivalent to the deprecated datetime.utcnow().
        self.now = datetime.now(timezone.utc).replace(tzinfo = None)
        try:
            self._load()
        except Exception:
            # Do not leak the connection if the database is unusable.
            self._dbconn.close()
            raise

    def _load(self):
        """Create the schema if needed, else prune and read stored entries."""
        cur = self._dbconn.cursor()
        try:
            # Ask the database whether the table exists rather than checking
            # for the file; an existing but uninitialised file would
            # otherwise make every run fail with "no such table".
            cur.execute("SELECT count(*) FROM sqlite_master "
                        "WHERE type = 'table' AND name = 'active_restore'")
            if cur.fetchone()[0] == 0:
                cur.execute(db_schema)
                return
            cur.execute(db_prune)
            cur.execute(db_select)
            for pnfs_id, first_seen_str, last_seen_str in cur:
                first_seen = datetime.strptime(first_seen_str,
                                               '%Y-%m-%d %H:%M:%S')
                last_seen = datetime.strptime(last_seen_str,
                                              '%Y-%m-%d %H:%M:%S')
                self.data[pnfs_id] = (first_seen, last_seen)
        finally:
            cur.close()

    def add(self, pnfs_id):
        """Register that pnfs_id is in restore during the current run."""
        if pnfs_id in self.data:
            self._to_update.append(pnfs_id)
            self.data[pnfs_id] = (self.data[pnfs_id][0], self.now)
        else:
            self._to_insert.append(pnfs_id)
            self.data[pnfs_id] = (self.now, self.now)

    def save(self):
        """Write the pending updates and inserts and commit them."""
        cur = self._dbconn.cursor()
        try:
            for pnfs_id in self._to_update:
                cur.execute(db_update, [pnfs_id])
            for pnfs_id in self._to_insert:
                cur.execute(db_insert, [pnfs_id])
        finally:
            cur.close()
        self._dbconn.commit()
        # Everything pending is now stored.  Forgetting it keeps a second
        # save() from inserting the same primary keys again.
        self._to_update = []
        self._to_insert = []

    def close(self):
        """Close the database connection."""
        self._dbconn.close()

def str_of_timestuff(timestuff):
    """Return str(timestuff) cut off before its first '.', if it has one.

    Used to print datetime and timedelta values without fractional seconds.
    """
    whole, _dot, _fraction = str(timestuff).partition('.')
    return whole

def check(url, certfile = None, keyfile = None,
          critical_limit = None, warning_limit = None):
    """Fetch the restore page, update the state database and report.

    url is the page to fetch.  If certfile is given, it and keyfile are
    loaded as the client certificate chain for the request.  critical_limit
    and warning_limit are timedelta thresholds for how long an ongoing
    restore may have been observed; a limit of None disables that threshold.

    The function prints the probe output and always terminates through
    sys.exit(): 2 if any ongoing restore exceeds critical_limit, else 1 if
    any exceeds warning_limit, else 0.  It exits with 3 if the page cannot
    be parsed, contains no table, or has rows without a PnfsId entry.
    """

    # Fetch restore information from dCache
    if certfile is None:
        context = None
    else:
        # NOTE(review): this pins the connection to TLS 1.2, and a context
        # constructed this way may not verify the server certificate.
        # Consider ssl.create_default_context() once it is confirmed that
        # the server certificate validates against the local trust store.
        context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
        context.load_cert_chain(certfile = certfile, keyfile = keyfile)
    decode = codecs.getreader('utf-8')
    parser = RestoreParser()
    try:
        with urllib.request.urlopen(url, context = context) as response:
            fh = decode(response)
            while True:
                buf = fh.read(2048)
                if buf == '':
                    break
                parser.feed(buf)
    except ParseError as err:
        # Report the failure before anything else which may fail.
        print('Failed to parse page: %s' % err)
        ofn = 'check_dcache_restore-crash-%s.html' % \
            datetime.now().strftime('%Y-%m-%dT%H%M')
        ofp = os.path.join(top_statedir, ofn)
        print('Re-fetching and saving content to %s.' % ofp)
        # Copy the raw bytes, so that saving does not depend on the locale
        # encoding of the environment the probe runs in.
        with urllib.request.urlopen(url, context = context) as response, \
                open(ofp, 'wb') as ofh:
            while True:
                chunk = response.read(65536)
                if not chunk:
                    break
                ofh.write(chunk)
        sys.exit(3)

    if parser.result is None:
        print('No restore data found.')
        print('URL: %s' % url)
        sys.exit(3)

    try:
        pnfs_ids = [restore['PnfsId'] for restore in parser.result]
    except KeyError:
        print('Restore table contains a row without a PnfsId entry.')
        print('URL: %s' % url)
        sys.exit(3)

    restore_state = RestoreState()
    try:
        # Update the database
        for pnfs_id in pnfs_ids:
            restore_state.add(pnfs_id)
        restore_state.save()

        # Check and report
        critical_count = 0
        warning_count = 0
        longest = timedelta()
        longest_recent = timedelta()
        for first_seen, last_seen in restore_state.data.values():
            duration = last_seen - first_seen
            # Entries seen in this run carry the timestamp of this run; only
            # those are still ongoing and count towards the limits.
            if last_seen == restore_state.now:
                if critical_limit is not None and duration > critical_limit:
                    critical_count += 1
                elif warning_limit is not None and duration > warning_limit:
                    warning_count += 1
                longest = max(duration, longest)
            longest_recent = max(duration, longest_recent)
        status_msgs = []
        if critical_count > 0:
            status_msgs.append('%d restores taking longer than %s' %
                               (critical_count, critical_limit))
            if warning_count > 0:
                status_msgs.append('another %d taking longer than %s' %
                                   (warning_count, warning_limit))
        elif warning_count > 0:
            status_msgs.append('%d restores taking longer than %s' %
                               (warning_count, warning_limit))
        status_msgs.append('longest ongoing restore %s'
                           % str_of_timestuff(longest))
        print(', '.join(status_msgs).capitalize())
        print('Longest recent restore: %s' % str_of_timestuff(longest_recent))
        if verbose:
            print('Recent restores and UTC times when seen by the probe:')
            for pnfs_id, (first_seen, last_seen) in restore_state.data.items():
                duration = last_seen - first_seen
                print('%s seen for %s (%s to %s)' %
                      (pnfs_id,
                       str_of_timestuff(duration),
                       str_of_timestuff(first_seen),
                       str_of_timestuff(last_seen)))
        if critical_count > 0:
            sys.exit(2)
        if warning_count > 0:
            sys.exit(1)
        sys.exit(0)
    finally:
        # Runs on every exit path, including the sys.exit() calls above.
        restore_state.close()

# Defaults for the settings taken from the command line.
url = keyfile = certfile = None
critical_limit = timedelta(hours = 12)
warning_limit = timedelta(hours = 8)

def parse_timedelta(s):
    """Convert a duration given on the command line to a timedelta.

    s is an integer optionally followed by a unit suffix: 'd' for days,
    'h' for hours, 'm' for minutes or 's' for seconds.  A number without a
    suffix is taken as seconds.

    Raises ValueError if s is not of that form.
    """
    seconds_per_unit = {'d': 86400, 'h': 3600, 'm': 60, 's': 1}
    text = s.strip()
    factor = 1
    # text[-1:] is '' for an empty string, which is not a known suffix.
    if text[-1:] in seconds_per_unit:
        factor = seconds_per_unit[text[-1]]
        text = text[:-1]
    try:
        count = int(text)
    except ValueError:
        raise ValueError('Invalid duration %r, expected an integer '
                         'optionally followed by d, h, m or s.' % s)
    return timedelta(0, count * factor)

# Parse the command line.  Usage errors are reported on stderr with exit
# status 3, the same status check() uses when it cannot produce a result.
try:
    opts, args = gnu_getopt(sys.argv[1:], 'U:c:w:v',
                            ['certkey=', 'cert=', 'verbose'])
    for opt, arg in opts:
        if opt in ('-c', '-w'):
            try:
                limit = parse_timedelta(arg)
            except ValueError:
                # Report a malformed limit as a usage error instead of
                # letting it escape as a traceback with exit status 1.
                raise GetoptError('Invalid duration %r for option %s.'
                                  % (arg, opt))
            if opt == '-c':
                critical_limit = limit
            else:
                warning_limit = limit
        elif opt == '-U':
            url = arg
        elif opt == '--certkey':
            keyfile = arg
        elif opt == '--cert':
            certfile = arg
        elif opt in ('-v', '--verbose'):
            # check() reads this module-level flag.
            verbose = True
        else:
            raise GetoptError('Unhandled option %s.' % opt)
    if url is None:
        raise GetoptError('Missing mandatory -U option.')
except GetoptError as err:
    sys.stderr.write('%s\n' % err)
    sys.exit(3)

try:
    check(url, certfile = certfile, keyfile = keyfile,
          critical_limit = critical_limit, warning_limit = warning_limit)
except (OSError, sqlite3.Error) as err:
    # Fetch and database failures would otherwise end the probe with a
    # traceback and exit status 1, which check() uses for a warning.
    print('Failed to check restores: %s' % err)
    sys.exit(3)
