#!/usr/bin/env python
# Copyright European Organization for Nuclear Research (CERN) 2013
#
# Licensed under the Apache License, Version 2.0 (the "License");
# You may not use this file except in compliance with the License.
# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
#
# Authors:
# - Stefan Prenner, <stefan.prenner@cern.ch>, 2017
#

'''
Probe that checks for each RSE if it is possible to download data using the webdav door and an X509 certificate without any extensions.
'''

import re
import sys
import traceback
import zlib

import requests
from rucio.client.client import Client


# Exit statuses
OK = 0
WARNING = 1
CRITICAL = 2
UNKNOWN = 3

# Running counter used to number the RSEs in the progress output of the main loop
index = 1


def set_browser_enabled(rse, browser_enabled):
    '''
    Stores the given value as the 'browser_enabled' attribute of an RSE, replacing a previous value if there is one.

    Uses the module-level Rucio client 'c', which is created in the main section of this script.

    :param rse: The RSE name.
    :param browser_enabled: The new boolean value for the browser_enabled attribute.
    '''

    attribute = 'browser_enabled'
    current_attributes = c.list_rse_attributes(rse)
    if attribute in current_attributes:
        # Remove the old value before adding the new one
        # (presumably an existing attribute cannot be overwritten in place - confirm against the Rucio client)
        c.delete_rse_attribute(rse, attribute)
    c.add_rse_attribute(rse, attribute, browser_enabled)


def verifyDownload(rse, response, checksum, success_list, error_list, wrong_checksum_list):
    '''
    Verifies the response byte sequence by comparing its adler32 hash with the stored checksum.

    Outcomes:
    - checksum matches: 'browser_enabled' is set to True for the RSE and the RSE is added to success_list.
    - checksum differs: the RSE and the computed checksum are added to wrong_checksum_list; the
      'browser_enabled' attribute is left untouched.
    - any exception while verifying: a description is added to error_list and 'browser_enabled' is set to False.

    :param rse: The RSE name.
    :param response: The Response object returned by the requests GET call.
    :param checksum: The correct checksum stored in the database (compared against a zero-padded, lower-case, 8 digit hex string).
    :param success_list: A list to keep track of all successful file downloads.
    :param error_list: A list to keep track of all errors except http errors during the GET request.
    :param wrong_checksum_list: A list to keep track of all files that produce an adler32 hash different to the one stored in the database.
    '''

    try:
        # Masking with 0xffffffff yields the same unsigned 32 bit value on every Python version and platform
        # (older interpreters may return a signed integer).
        adler = zlib.adler32(response.content) & 0xffffffff
        computed = '%08x' % adler
        print('Adler checksum: ' + computed)
        if computed != checksum:
            wrong_checksum_list.append(str(rse) + ' : checksum ' + computed)
        else:
            # Update the attribute first: if that call fails the RSE must only show up in the error list,
            # not in the success list as well.
            set_browser_enabled(rse, True)
            success_list.append(str(rse))
            print('Checksum correct!')
    except Exception:
        # 'Exception' instead of a bare except so that KeyboardInterrupt/SystemExit still stop the probe
        e_type, e_value = sys.exc_info()[:2]
        # format_exc() gives the readable traceback; str() of a traceback object is only its memory address
        error_list.append(str(rse) + ': Error while verifying! ' + str(e_type) + ' ' + str(e_value) + ' ' + traceback.format_exc())
        set_browser_enabled(rse, False)
        print('An error occurred while verifying download, see error list for details.')


if __name__ == "__main__":
    # Iterates through all replicas storing the specified file and keeps track of occurring errors.
    # Adds current rse to one of 6 lists depending on the result (success, http error, other error,
    # skipped due to blacklisting, missing download link, wrong checksum/corrupted file).
    #
    # Command line arguments: <file_scope> <file_name> <cert_location>

    if len(sys.argv) < 4:
        # Fail with a clear message instead of an IndexError traceback
        sys.stderr.write('Usage: ' + sys.argv[0] + ' <file_scope> <file_name> <cert_location>\n')
        sys.exit(UNKNOWN)
    file_scope, file_name, cert_location = sys.argv[1:4]

    # 'c' has to stay a module-level name: set_browser_enabled() uses it
    c = Client()
    r = c.list_replicas([{'scope': file_scope, 'name': file_name}], schemes=['davs'])
    error_list = []
    empty_list = []
    wrong_checksum_list = []
    success_list = []
    http_error_list = []
    skipped_list = []

    # Three digit numbers found in an HTML answer are reported as the HTTP error;
    # \D on both sides avoids matching parts of 4 digit (or longer) numbers
    status_pattern = re.compile(r"\D(\d{3})\D")

    for replica in r:
        checksum = replica['adler32']
        rses = replica['rses']
        for rse in rses:
            p = c.get_protocols(str(rse), scheme='davs')  # skip rse if not available
            if p['availability_read'] is False:
                skipped_list.append(rse)
                continue

            links = rses[rse]
            if not links:
                print(str(index) + ': ' + str(rse) + ': Link is empty.')
                empty_list.append(rse)
                set_browser_enabled(rse, False)
            else:
                print(str(index) + ': ' + str(rse) + ' ...')
                try:
                    # The last link of the replica is used; only the leading scheme is rewritten
                    link = links[-1].replace('davs', 'https', 1)
                    # NOTE(review): verify=False switches off the verification of the server certificate
                    # and no timeout is set, so an unresponsive door can block the probe - confirm both are intended.
                    response = requests.get(link, cert=cert_location, verify=False)
                    content_type = response.headers.get('content-type')
                    if content_type is None:
                        print('HTTP Header did not have content-type attribute. Attempting download...')
                        verifyDownload(rse, response, checksum, success_list, error_list, wrong_checksum_list)
                    elif 'text/html' in content_type:
                        # An HTML answer is treated as an error page of the door.
                        # NOTE(review): 'browser_enabled' is not changed in this case - confirm that this is intended.
                        codes = sorted(set(status_pattern.findall(str(response.text))))  # set to delete duplicates, sorted for a stable output
                        http_error_list.append(str(rse) + ': ' + str(codes))
                    else:
                        verifyDownload(rse, response, checksum, success_list, error_list, wrong_checksum_list)
                except Exception:
                    # 'Exception' instead of a bare except so that KeyboardInterrupt/SystemExit still stop the probe
                    e_type, e_value = sys.exc_info()[:2]
                    # format_exc() gives the readable traceback; str() of a traceback object is only its memory address
                    error_list.append(str(rse) + ': ' + str(e_type) + ' ' + str(e_value) + ' ' + traceback.format_exc())
                    set_browser_enabled(rse, False)
                    print('An error occurred, see error list for details.')

            index += 1
            print('Browser enabled for ' + str(rse) + ': ' + str(c.list_rse_attributes(rse).get('browser_enabled')))

    summary = (
        ('Empty links', empty_list),
        ('Http Error list', http_error_list),
        ('Links of other errors', error_list),
        ('List of wrong checksums', wrong_checksum_list),
        ('Success', success_list),
        ('Skipped RSEs', skipped_list),
    )
    for label, entries in summary:
        print(label + ' (' + str(len(entries)) + '): ' + str(entries))
    sys.exit(OK)
