#!/usr/bin/env python
#
# Copyright European Organization for Nuclear Research (CERN)

# Licensed under the Apache License, Version 2.0 (the "License");
# You may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#                       http://www.apache.org/licenses/LICENSE-2.0
#
# Authors:
# - David Cameron, <david.cameron@cern.ch>, 2015
# - Vincent Garonne, <vincent.garonne@cern.ch>, 2015
# - Cedric Serfon, <cedric.serfon@cern.ch>, 2016-2019
#
#
# Set quotas for physgroups
# Set quotas for users on their country's localgroupdisk
# Set quotas for all users on scratchdisk
#
# PY3K COMPATIBLE

from __future__ import print_function
import json
import sys

import requests

from rucio.api.rse import get_rse, parse_rse_expression, get_rse_usage, list_rse_attributes
from rucio.api.account import list_accounts
from rucio.api.account_limit import set_account_limit
from rucio.api.config import get
from rucio.db.sqla.session import get_session
from rucio.common.exception import RucioException, RSENotFound, AccountNotFound


# Exit statuses of the probe, ordered by increasing severity
# (this looks like the Nagios plugin convention -- TODO confirm).
OK, WARNING, CRITICAL, UNKNOWN = range(4)

# Status the script finally exits with; the functions below overwrite it
# when they run into a problem.
RESULT = OK


def get_storage_total(name, ddm_type):
    """
    Return the total capacity that the 'storage' usage source reports for an RSE.

    :param name: Name of the RSE.
    :param ddm_type: Endpoint type of the RSE. A failed lookup is only logged
                     for 'TEST' endpoints; for other endpoints of tier 0, 1 or 2
                     it raises the probe status to WARNING.
    :returns: The 'total' value of the first usage record, or None if the
              capacity could not be determined.
    """
    global RESULT
    try:
        capacity = get_rse_usage(name, 'root', source='storage')
    except RucioException as error:
        if ddm_type == 'TEST':
            print("Ignoring failed 'get capacity' for TEST endpoint: %s: %s" % (name, str(error)))
            return None
        # The tier is only needed to decide whether the failure deserves a
        # warning, so it is looked up here and not on the success path.
        try:
            tier = list_rse_attributes(name).get('tier')
            tier_known = True
        except RucioException:
            tier = None
            tier_known = False
        # If the tier itself cannot be looked up, warn anyway: the RSE cannot
        # be shown to be outside tiers 0-2.
        if not tier_known or tier in ['0', '1', '2']:
            print("WARNING: Could not get capacity of %s: %s" % (name, str(error)))
            # Never downgrade a more severe status set earlier in the run.
            RESULT = max(RESULT, WARNING)
        # Every failure path returns here; there is no capacity to read.
        return None

    try:
        return capacity[0]['total']
    except (IndexError, KeyError):
        print('No storage info for %s' % name)
        return None


def set_group_quotas():
    """
    Align the account limits of physics groups in Rucio with the quotas listed by AGIS.

    The limits currently stored for accounts matching an RSE's 'physgroup'
    attribute are read from the database, then compared with the 'quotas' of
    every active AGIS DDM endpoint. Only the endpoint's first physics group is
    considered. Missing or differing limits are written with set_account_limit.
    An AGIS quota of 999999 is treated as infinite and stored as -1; other
    values are multiplied by 1000 ** 4 (presumably TB to bytes -- TODO confirm).

    Sets RESULT to CRITICAL if AGIS cannot be queried, and to at least WARNING
    if an AGIS endpoint is unknown to Rucio.
    """
    global RESULT
    quotas = {}
    session = get_session()
    query = '''select account, atlas_rucio.id2rse(c.rse_id), bytes from atlas_rucio.account_limits c,
               (select rse_id, rse, value from atlas_rucio.rses a, atlas_rucio.rse_attr_map b where a.id=b.rse_id and b.key='physgroup' and b.value!='None' and a.deleted!=1) d
               where c.rse_id=d.rse_id and c.account=d.value'''
    try:
        for account, rse, bytes_size in session.execute(query):
            quotas[account, rse] = bytes_size
    except Exception as error:
        # Best effort: with an empty (or partial) map the affected quotas are
        # treated as new below and simply written again.
        print(error)

    url = 'http://atlas-agis-api.cern.ch/request/ddmendpoint/query/list/?json&state=ACTIVE&site_state=ACTIVE'
    try:
        resp = requests.get(url=url)
        data = json.loads(resp.content)
    except Exception as error:
        print("Failed to load info from AGIS: %s" % str(error))
        RESULT = CRITICAL
        # Without the AGIS data there is nothing to compare against.
        return

    def _apply_limit(group, rse, limit):
        # Write one limit; an unknown account is reported but does not stop the run.
        try:
            set_account_limit(group, rse, limit, 'root')
        except AccountNotFound as error:
            print(error)

    for ddmendpoint in data:
        # Check if RSE exists
        name = ddmendpoint['name']
        try:
            get_rse(name)
        except RSENotFound:
            print("WARNING: RSE %s missing in Rucio" % name)
            # Never downgrade a more severe status set earlier in the run.
            RESULT = max(RESULT, WARNING)
            continue

        if not ddmendpoint['phys_groups']:
            continue
        main_group = ddmendpoint['phys_groups'][0]

        for vomsgroup in ddmendpoint['quotas']:
            # The group name is taken from the third '/'-separated field of the key.
            parts = vomsgroup.split('/')
            if len(parts) < 3:
                # Malformed key: skip it instead of aborting the whole probe.
                continue
            group = parts[2]
            if group != main_group:
                continue

            agis_quota = ddmendpoint['quotas'][vomsgroup]
            infinite = agis_quota == 999999
            expected = -1 if infinite else agis_quota * 1000 ** 4

            if (group, name) not in quotas:
                print('Will add quota on RSE %s for group %s' % (name, group))
                # An infinite AGIS quota is stored as -1 straight away rather
                # than as 999999 * 1000 ** 4 until the next run corrects it.
                _apply_limit(group, name, expected)
            elif infinite:
                if quotas[(group, name)] != -1:
                    print('Infinite quota defined in AGIS for group %s on %s' % (group, name))
                    _apply_limit(group, name, -1)
            elif quotas[(group, name)] != expected:
                print('On %s quota for %s differs Rucio : %s vs AGIS : %s' % (name, group, quotas[(group, name)], expected))
                _apply_limit(group, name, expected)


def set_user_quotas(ddm_type, fraction, absolute=0, account_type='USER'):
    """
    Create or update the account limits of all accounts of one type on the RSEs of one type.

    The expected limit of an RSE is its 'default_account_limit_bytes'
    attribute if present; otherwise it is `fraction` of the total space
    returned by get_storage_total, capped at `absolute` when that is non-zero.
    RSEs with neither a known total space nor the attribute are left alone.

    :param ddm_type: Value of the RSE 'type' attribute selecting the RSEs.
    :param fraction: Share of the RSE total space used as the default limit.
    :param absolute: Upper bound for the computed default limit; 0 disables the cap.
    :param account_type: Account type passed as filter to list_accounts.

    Any exception raised while computing or applying the limits is printed
    and sets RESULT to CRITICAL.
    """
    global RESULT
    expected = {}
    current = {}
    session = get_session()
    user_accounts = [acc['account'] for acc in list_accounts({'account_type': account_type})]

    def _push_limit(account, rse_name):
        # Write the expected limit; an unknown account is only reported.
        try:
            set_account_limit(account, rse_name, expected[rse_name], 'root')
        except AccountNotFound as error:
            print(error)

    try:
        # Step 1: expected limit per RSE.
        for rse_name in parse_rse_expression('type=%s' % ddm_type):
            capacity = get_storage_total(rse_name, ddm_type)
            attributes = list_rse_attributes(rse_name)
            if capacity:
                fallback = float(capacity) * fraction
                if absolute:
                    fallback = min(absolute, fallback)
                expected[rse_name] = float(attributes.get('default_account_limit_bytes', fallback))
            elif 'default_account_limit_bytes' in attributes:
                expected[rse_name] = float(attributes['default_account_limit_bytes'])

        # Step 2: limits currently stored, grouped per account.
        query = '''select account, atlas_rucio.id2rse(rse_id), bytes from atlas_rucio.account_limits where
rse_id in (select id from atlas_rucio.rses a, atlas_rucio.rse_attr_map b where a.id=b.rse_id and b.key='type' and b.value='%s' and a.deleted!=1) ''' % (ddm_type)
        for account, rse_name, bytes_size in session.execute(query):
            current.setdefault(account, {})[rse_name] = bytes_size

        # Step 3: reconcile every account with the expected limits.
        for account in user_accounts:
            held = current.get(account)
            if held is None:
                print('%s is missing' % (account))
                held = {}
            else:
                for rse_name in held:
                    if rse_name not in expected:
                        print('%s cannot be found in quotas dictionary' % (rse_name))
                        continue
                    # Differences of at most 1000 are tolerated.
                    if abs(held[rse_name] - expected[rse_name]) > 1000:
                        print('%s %s : Defined quota %s different from expected quota %s' % (rse_name, account, held[rse_name], expected[rse_name]))
                        _push_limit(account, rse_name)

            # RSEs on which the account has no limit yet.
            for rse_name in expected:
                if rse_name in held:
                    continue
                print('Will add quota for account %s on %s' % (account, rse_name))
                _push_limit(account, rse_name)

    except Exception as error:
        print(error)
        RESULT = CRITICAL


def set_localgroupdisk_quotas(fraction, absolute):
    """
    Give accounts with a 'country-XX' attribute a limit on the LOCALGROUPDISK RSEs of that country.

    Only missing limits are created; existing limits are never modified.
    The limit is the RSE attribute 'default_account_limit_bytes' if set,
    otherwise `fraction` of the RSE total space, capped at `absolute` when
    that is non-zero. RSEs for which neither value is known are skipped.

    :param fraction: Share of the RSE total space used as the default limit.
    :param absolute: Upper bound for the computed default limit; 0 disables the cap.

    Any failure of a database query or of the per-country processing is
    printed and sets RESULT to CRITICAL.
    """
    global RESULT
    countries = {}       # 'country-XX' attribute key -> list of active accounts
    dict_sites = {}      # country value -> list of LOCALGROUPDISK RSEs
    total_space = {}     # RSE -> total space (None if unknown)
    session = get_session()
    default_values = {}  # RSE -> default_account_limit_bytes attribute

    try:
        query = '''select atlas_rucio.id2rse(rse_id), value from atlas_rucio.rses a, atlas_rucio.rse_attr_map b
                   where a.id=b.rse_id and b.key='country' and a.rse like '%LOCALGROUPDISK' and a.deleted!=1'''
        result = session.execute(query)
        for rse, country in result:
            dict_sites.setdefault(country, []).append(rse)
            if rse not in total_space:
                total_space[rse] = get_storage_total(rse, 'LOCALGROUPDISK')
                val = list_rse_attributes(rse)
                if 'default_account_limit_bytes' in val:
                    default_values[rse] = float(val['default_account_limit_bytes'])
    except Exception as error:
        print(error)
        RESULT = CRITICAL

    try:
        query = '''select a.account, a.key, a.value, a.updated_at, a.created_at from atlas_rucio.account_attr_map a, atlas_rucio.accounts b
                   where a.account=b.account and key like 'country-%' and b.status='ACTIVE' '''
        result = session.execute(query)
        for account, key, _, _, _ in result:
            countries.setdefault(key, []).append(account)
    except Exception as error:
        print(error)
        RESULT = CRITICAL

    for country in countries:
        accounts = {}  # account -> {RSE: bytes} for the limits already defined
        try:
            # Attribute keys look like 'country-XX'; the RSE attribute holds only 'XX'.
            country_short = country.split('-')[1]
            query = '''select account, atlas_rucio.id2rse(rse_id), bytes from atlas_rucio.account_limits where
                       account in (select c.account from atlas_rucio.account_attr_map c, atlas_rucio.accounts d where d.status='ACTIVE'
                       and c.account=d.account and key='%s')
                       and rse_id in (select id from atlas_rucio.rses a, atlas_rucio.rse_attr_map b
                       where a.id=b.rse_id and b.key='country' and b.value='%s' and a.rse like '%%LOCALGROUPDISK') ''' % (country, country_short)
            result = session.execute(query)
            for account, rse, bytes_size in result:
                accounts.setdefault(account, {})[rse] = bytes_size

            for account in countries[country]:
                sites_with_no_quota = []
                if account not in accounts or accounts[account] == {}:
                    sites_with_no_quota = dict_sites[country_short]
                    print('%s : %s account is missing quota on sites : %s' % (country, account, sites_with_no_quota))
                elif len(accounts[account]) < len(dict_sites[country_short]):
                    sites_with_no_quota = list(set(dict_sites[country_short]) - set(accounts[account]))
                    print('%s : %s account is missing quota on some sites %s vs %s : %s' % (country, account, len(accounts[account]), len(dict_sites[country_short]), sites_with_no_quota))

                for rse in sites_with_no_quota:
                    quota = None
                    if rse in default_values:
                        quota = default_values[rse]
                    elif total_space.get(rse):
                        # Honour the absolute cap; the computed value used to be
                        # overwritten by an uncapped one on the following line.
                        quota = float(total_space[rse]) * fraction
                        if absolute:
                            quota = min(absolute, quota)
                    if quota:
                        print("Set quota of %dTB on %s for %s" % (quota / 1000 ** 4, rse, account))
                        try:
                            set_account_limit(account, rse, quota, 'root')
                        except AccountNotFound as error:
                            print(error)
        except Exception as error:
            print(error)
            RESULT = CRITICAL


if __name__ == '__main__':
    def _read_quota_setting(option, default, divisor=1.0):
        # Read option from the 'quota' config section as a float divided by
        # divisor; fall back to default if it is missing or not a number.
        try:
            return float(get('quota', option, 'root')) / divisor
        except Exception:
            return default

    # Relative quotas are configured in percent and converted to fractions.
    # Absolute quotas are now converted to numbers too: they used to be kept
    # as returned by the configuration, which breaks the comparison with the
    # computed float quota if that value is a string.
    REL_QUOTA_SCRATCHDISK = _read_quota_setting('rel_SCRATCHDISK', 0.3, 100.0)
    ABS_QUOTA_SCRATCHDISK = _read_quota_setting('abs_SCRATCHDISK', 20000000000000)
    REL_QUOTA_SCRATCHDISK_GROUP = _read_quota_setting('rel_SCRATCHDISK_group', 0.1, 100.0)
    # NOTE(review): this option ends in '_GROUP' while the relative one ends in
    # '_group' -- confirm which spelling the configuration actually uses.
    ABS_QUOTA_SCRATCHDISK_GROUP = _read_quota_setting('abs_SCRATCHDISK_GROUP', 10000000000000)
    REL_QUOTA_LOCALGROUPDISK = _read_quota_setting('rel_LOCALGROUPDISK', 0.95, 100.0)
    # NOTE(review): dividing an absolute quota by 100 looks like a copy/paste
    # from the relative settings; kept unchanged -- confirm the intended unit.
    ABS_QUOTA_LOCALGROUPDISK = _read_quota_setting('abs_LOCALGROUPDISK', 0, 100.0)

    # For groups the quota is set and updated using the AGIS information
    set_group_quotas()
    # For the SCRATCHDISK area, the quota is created and updated each time the size of the space token changes
    # If the RSE has the attribute default_account_limit_bytes set, that value overrides the computed default
    set_user_quotas(ddm_type='SCRATCHDISK', fraction=REL_QUOTA_SCRATCHDISK, absolute=ABS_QUOTA_SCRATCHDISK, account_type='USER')
    set_user_quotas(ddm_type='SCRATCHDISK', fraction=REL_QUOTA_SCRATCHDISK_GROUP, absolute=ABS_QUOTA_SCRATCHDISK_GROUP, account_type='GROUP')
    # For the LOCALGROUPDISK area, the quota is created once but never updated
    # The value set is 95% of the space (by default) or default_account_limit_bytes if it is set as RSE attribute
    set_localgroupdisk_quotas(fraction=REL_QUOTA_LOCALGROUPDISK, absolute=ABS_QUOTA_LOCALGROUPDISK)
    sys.exit(RESULT)
