#!/usr/bin/env python
"""
 Copyright European Organization for Nuclear Research (CERN)

 Licensed under the Apache License, Version 2.0 (the "License");
 You may not use this file except in compliance with the License.
 You may obtain a copy of the License at
                       http://www.apache.org/licenses/LICENSE-2.0

 Authors:
 - David Cameron, <david.cameron@cern.ch>, 2015
 - Vincent Garonne, <vincent.garonne@cern.ch>, 2015
 - Cedric Serfon, <cedric.serfon@cern.ch>, 2016-2017


 Set quotas for physgroups
 Set quotas for users on their country's localgroupdisk
 Set quotas for all users on scratchdisk
"""

import json
import sys

import requests

from rucio.api.rse import get_rse, parse_rse_expression, get_rse_usage, list_rse_attributes
from rucio.api.account import list_accounts
from rucio.api.account_limit import set_account_limit
from rucio.api.config import get
from rucio.db.sqla.session import get_session
from rucio.common.exception import RucioException, RSENotFound, AccountNotFound


# Exit status codes; RESULT is what the script finally hands to sys.exit().
# NOTE(review): the values follow the usual monitoring-plugin convention
# (0 = OK ... 3 = UNKNOWN) - presumably consumed by Nagios, to be confirmed.
OK, WARNING, CRITICAL, UNKNOWN = range(4)

# Worst status recorded so far; the functions below update it via `global`.
RESULT = OK


def get_storage_total(name, ddm_type):
    """
    Return the total space reported by the storage for an RSE.

    On a failed usage lookup the global RESULT is raised to WARNING only for
    non-TEST endpoints whose 'tier' attribute is '0', '1' or '2'.

    :param name: The RSE name.
    :param ddm_type: The endpoint type; failures on 'TEST' endpoints are ignored.
    :returns: The 'total' value of the first usage record, or None if the
              usage could not be retrieved or contains no such record.
    """
    global RESULT
    try:
        capacity = get_rse_usage(name, 'root', source='storage')
    except RucioException as error:
        if ddm_type == 'TEST':
            print("Ignoring failed 'get capacity' for TEST endpoint: %s: %s" % (name, str(error)))
            return None
        # The tier is only needed to decide whether the failure deserves a
        # warning, so it is looked up here (best effort) and not on the
        # success path, where a missing attribute must not break the result.
        try:
            tier = list_rse_attributes(name).get('tier')
        except RucioException:
            tier = None
        if tier in ['0', '1', '2']:
            print("WARNING: Could not get capacity of %s: %s" % (name, str(error)))
            RESULT = WARNING
        # Without a usage record there is nothing to return, whatever the tier.
        return None

    try:
        return capacity[0]['total']
    except (IndexError, KeyError):
        # Report the RSE concerned rather than the bare exception text.
        print('No storage info for %s' % name)
        return None


def set_group_quotas():
    """
    Synchronise the physgroup account limits with the quotas published by AGIS.

    The limits currently defined for physgroup accounts are read from the
    database, then compared with the 'quotas' of every active AGIS endpoint
    that has a 'phys_groups' entry. Missing or differing limits are (re)set.
    AGIS quotas are multiplied by 1000 ** 4 before being stored; the AGIS
    value 999999 is treated as infinite and stored as -1.

    Sets the global RESULT to CRITICAL if AGIS cannot be read and to WARNING
    if an AGIS endpoint is unknown to Rucio.
    """
    global RESULT
    quotas = {}
    session = get_session()
    query = '''select account, atlas_rucio.id2rse(c.rse_id), bytes from atlas_rucio.account_limits c,
               (select rse_id, rse, value from atlas_rucio.rses a, atlas_rucio.rse_attr_map b where a.id=b.rse_id and b.key='physgroup' and b.value!='None' and a.deleted!=1) d
               where c.rse_id=d.rse_id and c.account=d.value'''
    try:
        result = session.execute(query)
        for account, rse, bytes_size in result:
            quotas[account, rse] = bytes_size
    except Exception as error:
        print(error)

    url = 'http://atlas-agis-api.cern.ch/request/ddmendpoint/query/list/?json&state=ACTIVE&site_state=ACTIVE'
    try:
        resp = requests.get(url=url)
        data = json.loads(resp.content)
    except Exception as error:
        print("Failed to load info from AGIS: %s" % str(error))
        RESULT = CRITICAL
        # Nothing to compare against: stop here instead of iterating over an
        # undefined variable.
        return

    for ddmendpoint in data:
        # Check if RSE exists
        name = ddmendpoint['name']
        try:
            get_rse(name)
        except RSENotFound:
            print("WARNING: RSE %s missing in Rucio" % name)
            RESULT = WARNING
            continue

        if not ddmendpoint['phys_groups']:
            continue

        # Only the first physgroup of the endpoint gets a quota.
        owner = ddmendpoint['phys_groups'][0]
        for vomsgroup in ddmendpoint['quotas']:
            group = vomsgroup.split('/')[2]
            if group != owner:
                continue

            agis_quota = ddmendpoint['quotas'][vomsgroup]
            infinite = agis_quota == 999999
            # Decide the target value up front so that a brand-new infinite
            # quota is written as -1 directly instead of being stored as
            # 999999 * 1000 ** 4 and only corrected on the following run.
            expected = -1 if infinite else agis_quota * 1000 ** 4

            if (group, name) not in quotas:
                print('Will add quota on RSE %s for group %s' % (name, group))
            elif quotas[(group, name)] == expected:
                continue
            elif infinite:
                print('Infinite quota defined in AGIS for group %s on %s' % (group, name))
            else:
                print('On %s quota for %s differs Rucio : %s vs AGIS : %s' % (name, group, quotas[(group, name)], expected))

            try:
                set_account_limit(group, name, expected, 'root')
            except AccountNotFound as error:
                print(error)


def set_user_quotas(ddm_type, fraction):
    """
    Create or update the limit of every USER account on the RSEs of one type.

    The expected limit of an RSE is its 'default_account_limit_bytes'
    attribute when set, otherwise the storage total multiplied by `fraction`.
    Limits that are missing, or that differ from the expected value by more
    than 1000, are (re)set. Any unexpected error sets the global RESULT to
    CRITICAL.

    :param ddm_type: Value of the RSE 'type' attribute selecting the RSEs.
    :param fraction: Share of the total space used as limit when the RSE has
                     no 'default_account_limit_bytes' attribute.
    """
    global RESULT
    expected = {}
    current = {}
    session = get_session()
    user_accounts = [entry['account'] for entry in list_accounts({'account_type': 'USER'})]

    def apply_limit(account, rse):
        # Set the expected limit; an unknown account is reported, not fatal.
        try:
            set_account_limit(account, rse, expected[rse], 'root')
        except AccountNotFound as error:
            print(error)

    try:
        # Expected limit per RSE.
        for rse in parse_rse_expression('type=%s' % (ddm_type)):
            total_space = get_storage_total(rse, ddm_type)
            attributes = list_rse_attributes(rse)
            if total_space:
                fallback = float(total_space) * fraction
                expected[rse] = float(attributes.get('default_account_limit_bytes', fallback))
            elif 'default_account_limit_bytes' in attributes:
                expected[rse] = float(attributes['default_account_limit_bytes'])

        # Limits currently defined, grouped by account.
        query = '''select account, atlas_rucio.id2rse(rse_id), bytes from atlas_rucio.account_limits where
rse_id in (select id from atlas_rucio.rses a, atlas_rucio.rse_attr_map b where a.id=b.rse_id and b.key='type' and b.value='%s' and a.deleted!=1) ''' % (ddm_type)
        for account, rse, bytes_size in session.execute(query):
            current.setdefault(account, {})[rse] = bytes_size

        for account in user_accounts:
            if account not in current:
                # No limit at all for this account: create one on every RSE.
                print('%s is missing' % (account))
                for rse in expected:
                    print('Will add quota for account %s on %s' % (account, rse))
                    apply_limit(account, rse)
                continue

            defined = current[account]

            # Correct existing limits that drifted from the expected value.
            for rse in defined:
                if rse not in expected:
                    print('%s cannot be found in quotas dictionary' % (rse))
                    continue
                if abs(defined[rse] - expected[rse]) > 1000:
                    print('%s %s : Defined quota %s different from expected quota %s' % (rse, account, defined[rse], expected[rse]))
                    apply_limit(account, rse)

            # Add the limits that do not exist yet.
            for rse in expected:
                if rse in defined:
                    continue
                print('Will add quota for account %s on %s' % (account, rse))
                apply_limit(account, rse)

    except Exception as error:
        print(error)
        RESULT = CRITICAL


def set_localgroupdisk_quotas(fraction):
    """
    Give active accounts a limit on the LOCALGROUPDISK RSEs of their country.

    Accounts are matched to countries through their 'country-<xx>' account
    attributes, RSEs through their 'country' attribute. Only missing limits
    are created; existing ones are left untouched. The limit is the RSE
    attribute 'default_account_limit_bytes' when set, otherwise the storage
    total multiplied by `fraction`. Any unexpected error sets the global
    RESULT to CRITICAL.

    :param fraction: Share of the total space used as limit when the RSE has
                     no 'default_account_limit_bytes' attribute.
    """
    global RESULT
    countries = {}
    dict_sites = {}
    total_space = {}
    session = get_session()
    default_values = {}

    # LOCALGROUPDISK RSEs per country, with their total space and default limit.
    try:
        query = '''select atlas_rucio.id2rse(rse_id), value from atlas_rucio.rses a, atlas_rucio.rse_attr_map b
                   where a.id=b.rse_id and b.key='country' and a.rse like '%LOCALGROUPDISK' and a.deleted!=1'''
        result = session.execute(query)
        for rse, country in result:
            dict_sites.setdefault(country, []).append(rse)
            if rse not in total_space:
                total_space[rse] = get_storage_total(rse, 'LOCALGROUPDISK')
                val = list_rse_attributes(rse)
                if 'default_account_limit_bytes' in val:
                    default_values[rse] = float(val['default_account_limit_bytes'])
    except Exception as error:
        print(error)
        RESULT = CRITICAL

    # Active accounts per 'country-<xx>' attribute key.
    try:
        query = '''select a.account, a.key, a.value, a.updated_at, a.created_at from atlas_rucio.account_attr_map a, atlas_rucio.accounts b
                   where a.account=b.account and key like 'country-%' and b.status='ACTIVE' '''
        result = session.execute(query)
        for account, key, _, _, _ in result:
            countries.setdefault(key, []).append(account)
    except Exception as error:
        print(error)
        RESULT = CRITICAL

    for country in countries:
        accounts = {}
        try:
            country_short = country.split('-')[1]
            sites = dict_sites.get(country_short, [])
            if not sites:
                # A country without any LOCALGROUPDISK has nothing to set;
                # this is reported but not treated as an error.
                print('%s : no LOCALGROUPDISK RSE found, nothing to set' % country)
                continue

            query = '''select account, atlas_rucio.id2rse(rse_id), bytes from atlas_rucio.account_limits where
                       account in (select c.account from atlas_rucio.account_attr_map c, atlas_rucio.accounts d where d.status='ACTIVE'
                       and c.account=d.account and key='%s')
                       and rse_id in (select id from atlas_rucio.rses a, atlas_rucio.rse_attr_map b
                       where a.id=b.rse_id and b.key='country' and b.value='%s' and a.rse like '%%LOCALGROUPDISK') ''' % (country, country_short)
            result = session.execute(query)
            for account, rse, bytes_size in result:
                accounts.setdefault(account, {})[rse] = bytes_size

            for account in countries[country]:
                defined = accounts.get(account, {})
                # Test membership explicitly instead of comparing counts: the
                # limits query above does not filter deleted RSEs, so equal
                # counts do not guarantee that every active site is covered.
                sites_with_no_quota = [rse for rse in sites if rse not in defined]
                if not defined:
                    print('%s : %s account is missing quota on sites : %s' % (country, account, sites_with_no_quota))
                elif sites_with_no_quota:
                    print('%s : %s account is missing quota on some sites %s vs %s : %s' % (country, account, len(defined), len(sites), sites_with_no_quota))

                for rse in sites_with_no_quota:
                    quota = None
                    if rse in default_values:
                        quota = default_values[rse]
                    elif total_space.get(rse):
                        quota = total_space[rse] * fraction
                    if quota:
                        print("Set quota of %dTB on %s for %s" % (quota / 1000 ** 4, rse, account))
                        try:
                            set_account_limit(account, rse, quota, 'root')
                        except AccountNotFound as error:
                            print(error)
        except Exception as error:
            print(error)
            RESULT = CRITICAL


if __name__ == '__main__':
    # The fractions are configured as percentages in the 'quota' config
    # section; fall back to built-in defaults when the option is missing or
    # not a number. `except Exception` (rather than a bare except) keeps
    # SystemExit and KeyboardInterrupt from being swallowed.
    try:
        QUOTA_SCRATCHDISK = float(get('quota', 'SCRATCHDISK', 'root')) / 100
    except Exception:
        QUOTA_SCRATCHDISK = 0.3
    try:
        QUOTA_LOCALGROUPDISK = float(get('quota', 'LOCALGROUPDISK', 'root')) / 100
    except Exception:
        QUOTA_LOCALGROUPDISK = 0.95

    # For groups, the quota is set and updated using the AGIS information
    set_group_quotas()
    # For the SCRATCHDISK area, the quota is created and then updated each time the size of the space token changes
    # If the RSE has the attribute default_account_limit_bytes set, that value overrides the computed one
    set_user_quotas('SCRATCHDISK', QUOTA_SCRATCHDISK)
    # For the LOCALGROUPDISK area, the quota is created once but never updated
    # The value set is the configured fraction of the space (95% by default), or default_account_limit_bytes if it is set as RSE attribute
    set_localgroupdisk_quotas(QUOTA_LOCALGROUPDISK)
    sys.exit(RESULT)
