#!/usr/bin/env python
# Copyright European Organization for Nuclear Research (CERN) 2017
#
# Licensed under the Apache License, Version 2.0 (the "License");
# You may not use this file except in compliance with the License.
# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
#
# Authors:
# - Nicolo Magini, <Nicolo.Magini@cern.ch>, 2017

'''
Probe to check read/write with all protocols on target RSE with rucio client
'''

import argparse
import filecmp
import os
import re
import sys
import tempfile
import time
import uuid
from commands import getstatusoutput

from rucio.client.replicaclient import ReplicaClient
from rucio.client.rseclient import RSEClient
from rucio.common.config import config_get
from rucio.common.exception import RSENotFound, RSEProtocolNotSupported
from rucio.rse import rsemanager


def upload(rse, protocol, scope, lfn, no_register=False):
    """
    Stage out a local file by running the ``rucio upload`` command line tool.

    The command and its output are printed so they show up in the probe log.

    :param rse: destination RSE
    :param protocol: transfer protocol
    :param scope: file scope
    :param lfn: file lfn
    :param no_register: skip file registration and disable auto-deletion at lifetime expiration
    :raises Exception: when the rucio upload command exits with a non-zero status
    """

    # Registered test files are given a one hour lifetime; unregistered ones get none
    registration_flag = '--no-register' if no_register else '--lifetime 3600'

    command_args = (registration_flag, rse, protocol, scope, lfn)
    command = 'rucio upload %s --rse %s --protocol %s --scope %s %s' % command_args

    print('Will use the following command for stageOut:\n%s\n' % command)
    exit_code, output = getstatusoutput(command)
    print('stageOut output:\n%s\n' % output)

    if exit_code != 0:
        # Collapse the command output to a single line for the error message
        raise Exception('stageOut failed -- rucio upload did not succeed: %s'
                        % output.replace('\n', ''))


def download(rse, protocol, scope, lfn, dst, pfn=None):
    """
    Stage in a file by running the ``rucio download`` command line tool.

    The command and its output are printed so they show up in the probe log.

    :param rse:   source RSE
    :param protocol: transfer protocol
    :param scope: file scope
    :param lfn:   file lfn
    :param dst:   local destination directory
    :param pfn:   file pfn (required when the downloaded file is not registered in Rucio)
    :raises Exception: when the rucio download command exits with a non-zero status
    """

    # The --pfn option is only added when an explicit PFN was supplied
    explicit_pfn = '--pfn %s' % pfn if pfn else ''

    command_args = (dst, rse, protocol, explicit_pfn, scope, lfn)
    command = 'rucio download --dir %s --rse %s --protocol %s %s %s:%s' % command_args

    print('Will use the following command for stageIn:\n%s\n' % command)
    exit_code, output = getstatusoutput(command)
    print('stageIn output:\n%s\n' % output)

    if exit_code != 0:
        # Collapse the command output to a single line for the error message
        raise Exception('stageIn failed -- rucio download did not succeed: %s'
                        % output.replace('\n', ''))


def main():
    '''
    Run the upload and download tests and report the result.

    For every protocol of the target RSE that allows both WAN read and WAN
    write, a small test file is uploaded with ``rucio upload`` and a file is
    downloaded with ``rucio download`` (either the file that was just
    uploaded or a fixed DID). The per-protocol results are written to a CSV
    report in the directory the probe was started from.

    The function ends by calling sys.exit() with a Nagios status code:
    0 (OK), 1 (WARNING), 2 (CRITICAL) or 3 (UNKNOWN). The temporary working
    directory is removed on every exit path.
    '''
    # Function-scope import: shutil is only needed for the cleanup done here
    import shutil

    # Nagios plugin exit codes
    OK, WARNING, CRITICAL, UNKNOWN = 0, 1, 2, 3
    exitstatus = UNKNOWN

    parser = argparse.ArgumentParser()
    parser.add_argument('RSE', help='Target RSE to test')
    parser.add_argument('--no-register', '-n', action='store_true',
                        help='Do not register the uploaded file and disable auto-deletion')
    parser.add_argument('--color', '-c', action='store_true',
                        help='Enable color in rucio client output')
    group = parser.add_mutually_exclusive_group()
    group.add_argument('--download', '-d', metavar='DID', action='store',
                       default='user.mlassnig:user.mlassnig.pilot.test.single.hits',
                       help='Specify a fixed container DID for the download test')
    group.add_argument('--download-uploaded', '-du', action='store_true',
                       help='Use the same file in the upload and download tests')
    parser.add_argument('--testmode', '-t', action='store_true',
                        help='Run probe in standalone test mode outside nagios env')

    try:
        args = parser.parse_args()
    except SystemExit:
        # argparse exits on its own for usage errors and --help; report
        # every such exit to Nagios as UNKNOWN
        sys.exit(UNKNOWN)

    did_parts = args.download.split(':')
    if len(did_parts) != 2:
        print("Wrong format for DID %s - required format is scope:name" % args.download)
        sys.exit(UNKNOWN)
    fixed_scope, fixed_name = did_parts

    if not args.color:
        # Disable rucio client colors
        os.environ['RUCIO_LOGGING_FORMAT'] = '%(asctime)s %(levelname)s [%(message)s]'

    rse = args.RSE
    if not re.match('^[A-Za-z0-9_-]+$', rse):
        print('Invalid RSE name %s' % rse)
        sys.exit(UNKNOWN)

    print('Testing RSE:\n%s\n' % rse)

    scope = 'tests'

    timestamp = '%d' % time.time()
    fileuuid = str(uuid.uuid1())

    # Computed before the chdir below, so the report lands in the start directory
    reportfile = '%s/nagios-check_ruciomover.%s.%s.csv' % (os.getcwd(), rse, timestamp)

    workdir = tempfile.mkdtemp(prefix='%s.%s.' % (rse, timestamp))
    os.chdir(workdir)

    print('Running in working directory:\n%s\n' % os.getcwd())

    try:
        if not args.testmode:
            try:
                proxy = config_get('nagios', 'proxy')
                os.environ["X509_USER_PROXY"] = proxy
            except Exception:
                print("Failed to get proxy from rucio.cfg")
                sys.exit(UNKNOWN)

        rse_settings = rsemanager.get_rse_info(rse)

        cli = RSEClient()
        rse_attribs = cli.list_rse_attributes(rse)

        # .get(): an RSE without the 'istape' attribute is treated as non-tape
        # instead of aborting the probe with a KeyError
        if rse_attribs.get('istape') in (True, 'True'):
            print('Skipping sitemover tests on tape RSE')
            sys.exit(UNKNOWN)

        rse_protocols = rse_settings['protocols']

        # Check if the DID selected for the download test is present on the target RSE
        skip_download = False
        if not args.download_uploaded:
            repcli = ReplicaClient()
            reps = list(repcli.list_replicas([{'scope': fixed_scope,
                                               'name': fixed_name}]))
            # Only the first returned replica entry is inspected
            if not reps or rse not in reps[0]['rses']:
                skip_download = True

        report = {}

        for prot in rse_protocols:

            # Run tests only if the RSE supports read/write for this protocol.
            # Need to check for wan since the nagios probe is running remotely.
            if not (prot['domains']['wan']['read'] and prot['domains']['wan']['write']):
                continue

            scheme = prot['scheme']
            print('Testing protocol:\n%s\n' % scheme)

            # Tests that are skipped below keep the UNKNOWN status
            result = {'upload': 'UNKNOWN', 'download': 'UNKNOWN'}
            report[scheme] = result

            filename = '%s.nagios-check_ruciomover.%s.%s.%s.%s' % (scope, rse, scheme,
                                                                   timestamp, fileuuid)
            print('Will use the following test file for upload:\n%s\n' % filename)

            with open(filename, 'w') as fil:
                fil.write('test file for %s with %s at time %s\n' % (rse, scheme, timestamp))

            pfn = None
            if rse_settings['deterministic']:
                pfn = rsemanager.lfns2pfns(rse_settings, {'scope': scope, 'name': filename},
                                           scheme=scheme)['%s:%s' % (scope, filename)]
                print('Rucio will use the following deterministic PFN for upload:\n%s\n' % pfn)
            else:
                print('Cannot predetermine pfn on non-deterministic storage\n')
                if args.no_register:
                    # Without registration and without a known PFN the file
                    # could not be addressed afterwards
                    print('Skipping test')
                    os.unlink(filename)
                    continue

            try:
                upload(rse, scheme, scope, filename, no_register=args.no_register)
                result['upload'] = 'OK'
            except Exception as e:
                print('\n%s\n' % e)
                result['upload'] = 'CRITICAL'
                if args.download_uploaded:
                    # Nothing to download if the upload itself failed
                    os.unlink(filename)
                    continue

            dst = 'download/%s.%s.%s' % (rse, scheme, timestamp)

            pfn_to_download = None
            destination = None

            if args.download_uploaded:
                download_scope = scope
                did_to_download = filename
                destination = '%s/%s/%s' % (dst, download_scope, did_to_download)
                # The PFN needs to be specified explicitly when downloading an unregistered file
                if args.no_register:
                    pfn_to_download = pfn
            else:
                if skip_download:
                    print("Skipping download test because selected DID %s was not found on RSE %s\n" % (args.download, rse))
                    os.unlink(filename)
                    continue
                download_scope = fixed_scope
                did_to_download = fixed_name

            try:
                download(rse, scheme, download_scope,
                         did_to_download, dst, pfn=pfn_to_download)
                result['download'] = 'OK'
            except Exception as e:
                print('\n%s\n' % e)
                result['download'] = 'CRITICAL'

            # Compare the files only after a successful download: after a failed
            # one the destination does not exist and the CRITICAL result must stand
            if args.download_uploaded and result['download'] == 'OK':
                try:
                    # shallow=False forces a comparison of the file contents
                    files_match = filecmp.cmp(filename, destination, shallow=False)
                except OSError as e:
                    # The downloaded file is not where it was expected
                    print('\n%s\n' % e)
                    files_match = False

                if files_match:
                    print('Test successful: uploaded %s and downloaded %s files match\n' % (filename, destination))
                else:
                    print('WARNING - Uploaded %s and downloaded %s files differ\n' % (filename, destination))
                    result['download'] = 'WARNING'

            # Best-effort cleanup of everything this protocol test left behind
            os.unlink(filename)
            shutil.rmtree(dst, ignore_errors=True)

        print('Writing report file to:\n%s\n' % reportfile)
        with open(reportfile, 'w') as fil:
            for proto in report:
                for act in report[proto]:
                    fil.write('%s,%s,%s,%s,%s\n' % (rse, act, proto,
                                                    report[proto][act], timestamp))

        # Calculate overall exit status per-RSE as following:
        # 1) Ignore UNKNOWN test results
        # 2) If all test results are CRITICAL, the RSE status is CRITICAL
        # 3) If at least one test result is WARNING or CRITICAL, the RSE status is WARNING
        # 4) If all test results are OK, the RSE status is OK

        statuses = [status for results in report.values()
                    for status in results.values() if status != 'UNKNOWN']

        if not statuses:
            exitstatus = UNKNOWN
        elif all(status == 'CRITICAL' for status in statuses):
            exitstatus = CRITICAL
        elif 'WARNING' in statuses or 'CRITICAL' in statuses:
            exitstatus = WARNING
        elif all(status == 'OK' for status in statuses):
            exitstatus = OK

    except RSEProtocolNotSupported as e:
        print(e)
        sys.exit(CRITICAL)
    except RSENotFound as e:
        print(e)
        sys.exit(UNKNOWN)
    finally:
        # Remove the temporary working directory on every exit path,
        # including the sys.exit() calls above
        shutil.rmtree(workdir, ignore_errors=True)

    print('Probe ending for RSE %s with status %s' % (rse, exitstatus))
    sys.exit(exitstatus)


# Run the probe only when this file is executed as a script, not when imported
if __name__ == "__main__":
    main()
