#!/usr/bin/env python
# Copyright European Organization for Nuclear Research (CERN)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# You may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#                       http://www.apache.org/licenses/LICENSE-2.0
#
# Authors:
# - David Cameron, <david.cameron@cern.ch>, 2015
# - Vincent Garonne, <vincent.garonne@cern.ch>, 2015

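#
# Set quotas for physics groups
# Set quotas for users on their country's LOCALGROUPDISK
# Set quotas for all users on SCRATCHDISK, USERDISK and TEST endpoints
# In the last two cases quotas are set for all accounts only if the RSE capacity
# is not cached or has changed by more than 1TB; otherwise only new accounts and
# accounts whose country membership changed are updated. Positive quotas that
# already exist on a LOCALGROUPDISK are never overwritten.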

import json
import re
import requests
import sys
import traceback

from rucio.api import rse
from rucio.api.account import list_accounts, list_account_attributes
from rucio.api.account_limit import get_account_usage, get_account_limit, set_account_limit
from rucio.common.exception import RucioException, RSENotFound, AccountNotFound

# Probe exit codes (Nagios convention, see the /var/cache/nagios paths below)
OK, WARNING, CRITICAL, UNKNOWN = 0, 1, 2, 3

# Overall exit status; raised to WARNING or CRITICAL by the code below
# whenever a problem is found
result = OK

# country -> list of account names, built during this run
country_accounts = {}
# country -> list of account names, as read from the previous run's cache
country_accounts_history = {}

# Every non-admin user account known to Rucio in this run
all_accounts = []
# Accounts whose quotas were handled in the previous run (read from cache)
done_accounts = []
# RSE name -> total capacity as last observed
rse_capacity = {}

# Persistent state between runs
accounts_cache = '/var/cache/nagios/rucio_quotas_accounts.json'  # list of accounts
rse_capacity_cache = '/var/cache/nagios/rucio_quotas_rse_capacity.json'  # RSE -> capacity
country_accounts_cache = '/var/cache/nagios/rucio_quotas_country_accounts.json'  # country -> accounts


def get_srm_total(name, ddm_type):
    ''' Return the total capacity reported for RSE ``name``, or None.

    The 'srm' usage source is tried first; if it carries no usable total, the
    'gsiftp' source is used as a fallback. Failures are printed and, except
    for TEST endpoints, raise the global ``result`` to WARNING.

    :param name: RSE name.
    :param ddm_type: AGIS endpoint type (e.g. 'TEST', 'SCRATCHDISK').
    :returns: the 'total' field of the first usage record, or None.
    '''
    global result
    try:
        capacity = rse.get_rse_usage(name, 'root', source='srm')
        # A missing 'tier' attribute must not abort the whole run
        tier = rse.list_rse_attributes(name).get('tier')
    except RucioException as e:
        if ddm_type == 'TEST':
            print("Ignoring failed 'get capacity' for TEST endpoint: %s: %s" % (name, str(e)))
        else:
            print("WARNING: Could not get capacity of %s: %s" % (name, str(e)))
            result = WARNING
        return None

    try:
        return capacity[0]['total']
    except (IndexError, KeyError, TypeError):
        pass  # no usable srm record, try gsiftp below

    try:
        capacity = rse.get_rse_usage(name, 'root', source='gsiftp')
        # Previously this value was computed and then discarded
        return capacity[0]['total']
    except Exception:
        # Only tier 0/1/2 non-TEST endpoints are expected to report capacity
        if tier in ('0', '1', '2') and ddm_type != 'TEST':
            print("WARNING: No srm info for %s" % name)
            result = WARNING
        return None


def set_group_quotas(quotas, name=None):
    ''' Set quota for physics groups on RSE ``name``.

    :param quotas: mapping of AGIS quota keys to quota values. Only keys
                   matching '/atlas/<physgroup>/' are used; values are in TiB
                   and 999999 means "no limit" (set as -1).
    :param name: RSE name. When omitted, the module-level ``name`` (set by
                 the ``__main__`` endpoint loop) is used, which keeps the
                 original single-argument call working.
    '''
    global result
    if name is None:
        name = globals().get('name')

    # Check if quota is defined
    if not quotas:
        print("WARNING: No quota defined for %s" % name)
        result = WARNING
        return

    for quota, value in quotas.items():

        physgroup = re.match(r'/atlas/(.*)/', quota)
        if not physgroup:
            continue
        physgroup = physgroup.group(1)

        # in AGIS quota 999999 means no limit, so set infinite
        if value == 999999:
            size = -1
        else:
            size = value * 1024**4  # AGIS numbers are in TiB

        try:
            usage = rse.get_rse_usage(name, 'root', source='rucio')
            # Usage figures are only used for the log line below; missing
            # data must not prevent the quota from being set.
            try:
                rse_used = usage[0]['used']
            except (IndexError, KeyError, TypeError):
                rse_used = 0
            account_usage = get_account_usage(physgroup, name, 'root')
            try:
                account_used = account_usage[0]['bytes']
            except (IndexError, KeyError, TypeError):
                account_used = 0
            print("Set quota for %s on RSE %s to %sTB (used RSE %dTB, used account %dTB)" %
                  (physgroup, name, 'inf ' if size == -1 else str(size / 1000**4),
                   rse_used / 1000**4, account_used / 1000**4))
            set_account_limit(physgroup, name, size, 'root')
        except AccountNotFound:
            print("Account %s does not exist" % physgroup)
        except RucioException as e:
            print('WARNING: %s' % str(e))
            result = WARNING
        except Exception:
            print('CRITICAL: %s' % traceback.format_exc())
            print('%s %s %s %s' % (quota, value, physgroup, name))
            result = CRITICAL


def set_user_quotas(name, accs, fraction, ddm_type, overwrite=True):
    ''' Set quota of fraction of total capacity for accounts on an RSE.

    All accounts in ``accs`` are updated when the RSE capacity is not cached
    or differs from the cached value by more than 1TB; otherwise only the
    accounts not in ``done_accounts`` are updated. The observed capacity is
    stored in ``rse_capacity`` for the next run.

    :param name: RSE name.
    :param accs: list of account names.
    :param fraction: fraction of the total capacity granted to each account.
    :param ddm_type: AGIS endpoint type; failures on 'TEST' endpoints are
                     printed but do not raise ``result`` to WARNING.
    :param overwrite: if False, accounts that already have a positive limit
                      on this RSE are left untouched.
    '''
    global result
    # Check space compared to cache value
    srmtotal = get_srm_total(name, ddm_type)
    if not srmtotal:
        return

    accounts = accs
    if name not in rse_capacity:
        print("%s: no cached info" % name)
    elif abs(rse_capacity[name] - srmtotal) > 1000**4:
        # Changes of <1TB are not important enough to change all the quotas
        print("%s capacity changed from %dTB to %dTB" % (name, rse_capacity[name] / 1000**4, srmtotal / 1000**4))
    else:
        print("%s total capacity is unchanged since last run" % name)
        # Just change new accounts; a set keeps the membership test O(1)
        done = set(done_accounts)
        accounts = [a for a in accounts if a not in done]

    rse_capacity[name] = srmtotal
    quota = srmtotal * fraction
    for account in accounts:
        try:
            if not overwrite:
                limits = get_account_limit(account=account, rse=name)
                # .get: a missing RSE key must not abort the whole run
                current = limits.get(name) if limits else None
                if current is not None and current > 0:
                    print("Positive quota already set for %s to %s on %s" % (account, current, name))
                    continue
            print("Set quota of %dTB on %s for %s" % (quota / 1000**4, name, account))
            set_account_limit(account, name, quota, 'root')
        except RucioException as e:
            if ddm_type == 'TEST':
                print("%s failed, but ignoring it as TEST endpoint" % str(e))
            else:
                print('WARNING: %s' % str(e))
                result = WARNING


# Takes DDM endpoint quota information from AGIS and sets rucio account quotas
if __name__ == '__main__':
    try:
        url = 'http://atlas-agis-api.cern.ch/request/ddmendpoint/query/list/?json&state=ACTIVE&site_state=ACTIVE'
        try:
            resp = requests.get(url=url)
            # An HTTP error page must not be mistaken for endpoint data
            resp.raise_for_status()
            data = json.loads(resp.content)
        except Exception as e:
            print("Failed to load info from AGIS: %s" % str(e))
            sys.exit(CRITICAL)

        # Read capacity information from cached file
        try:
            with open(rse_capacity_cache) as f:
                rse_capacity = json.load(f)
        except (IOError, OSError, ValueError):
            print("No cached RSE capacity info, will set all quotas")

        # Read accounts handled in the previous run from cached file
        try:
            with open(accounts_cache) as f:
                done_accounts = json.load(f)
        except (IOError, OSError, ValueError):
            print("No cached accounts, will set quotas for all accounts")

        # Read previous country -> accounts mapping from cached file
        try:
            with open(country_accounts_cache) as f:
                country_accounts_history = json.load(f)
        except (IOError, OSError, ValueError):
            print("No cached countries with accounts, will set quotas for all accounts")

        # Get all rucio accounts and country attributes. Admin accounts are
        # excluded from both all_accounts and country_accounts.
        try:
            account_names = [a['account'] for a in list_accounts()]
            all_accounts = []
            for account in account_names:
                attrs = list(list_account_attributes(account))
                if any(attr['key'] == 'admin' and attr['value'] == 1 for attr in attrs):
                    continue
                all_accounts.append(account)
                for attr in attrs:
                    if attr['key'].startswith('country-'):
                        country = attr['key'].split('-')[1]
                        country_accounts.setdefault(country, []).append(account)
        except RucioException as e:
            print("Failed to list Rucio accounts or attributes: %s" % str(e))
            sys.exit(CRITICAL)

        # Get the accounts that changed their country membership and remove them from done_accounts
        # -> this ensures their quota is reset for all RSEs even if the RSE size was not changed.
        reset_accounts = set()
        for country, members in country_accounts.items():
            if country not in country_accounts_history:
                print("Country not in cache: %s - all quotas of their members will be reset: %s " % (country, ",".join(members)))
                reset_accounts.update(members)
            else:
                diff = set(country_accounts_history[country]).symmetric_difference(members)
                if diff:
                    print("These accounts changed their country membership - their quotas will be reset: %s" % ",".join(sorted(diff)))
                    reset_accounts.update(diff)
        done_accounts = [acc for acc in done_accounts if acc not in reset_accounts]

        for ddmendpoint in data:

            # Check if RSE exists; ``name`` is also read as a module global
            # by set_group_quotas
            name = ddmendpoint['name']
            try:
                rse.get_rse(name)
            except RSENotFound:
                print("WARNING: RSE %s missing in Rucio" % name)
                result = WARNING
                continue

            if ddmendpoint['phys_groups']:
                set_group_quotas(ddmendpoint['quotas'])

            if ddmendpoint['type'] == 'LOCALGROUPDISK':
                try:
                    country = rse.list_rse_attributes(name)['country']
                except (KeyError, RucioException):
                    print("WARNING: No country defined for %s" % name)
                    result = WARNING
                    continue

                if country not in country_accounts:
                    print("No accounts in country %s" % country)
                else:
                    set_user_quotas(name, country_accounts[country], 0.95, 'LOCALGROUPDISK', overwrite=False)

            if ddmendpoint['type'] in ('SCRATCHDISK', 'USERDISK', 'TEST'):
                set_user_quotas(name, all_accounts, 0.5, ddmendpoint['type'])

        # Write back cache files (best effort: failures are only reported)
        try:
            with open(accounts_cache, 'w') as f:
                json.dump(all_accounts, f)
        except (IOError, OSError, TypeError, ValueError):
            print("Failed to write accounts cache")

        try:
            with open(rse_capacity_cache, 'w') as f:
                json.dump(rse_capacity, f)
        except (IOError, OSError, TypeError, ValueError):
            print("Failed to write RSE capacity cache")

        try:
            with open(country_accounts_cache, 'w') as f:
                json.dump(country_accounts, f)
        except (IOError, OSError, TypeError, ValueError):
            print("Failed to write country accounts cache")

    except Exception:
        print(traceback.format_exc())
        result = CRITICAL

    sys.exit(result)
