#!/usr/bin/env python
# Copyright European Organization for Nuclear Research (CERN)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# You may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#                       http://www.apache.org/licenses/LICENSE-2.0
#
# Authors:
# - David Cameron, <david.cameron@cern.ch>, 2015
# - Vincent Garonne, <vincent.garonne@cern.ch>, 2015
# - Cedric Serfon, <cedric.serfon@cern.ch>, 2016

#
# Set quotas for physgroups
# Set quotas for users on their country's localgroupdisk
# Set quotas for all users on scratchdisk


import json
import requests
import sys

from rucio.api.rse import get_rse, parse_rse_expression, get_rse_usage, list_rse_attributes
from rucio.api.account import list_accounts
from rucio.api.account_limit import set_account_limit
from rucio.db.sqla.session import get_session
from rucio.common.exception import RucioException, RSENotFound, AccountNotFound


# Status codes of this script; RESULT is the value handed to sys.exit() at
# the end of the run.
# NOTE(review): the values look like the Nagios plugin return codes - confirm.
UNKNOWN = 3
CRITICAL = 2
WARNING = 1
OK = 0

# Overall status of the run; the functions below overwrite it with WARNING or
# CRITICAL when they hit a problem.
RESULT = OK


def get_srm_total(name, ddm_type):
    """
    Return the total space reported for an RSE.

    The usage reported with source 'srm' is used when it carries a 'total'
    value; otherwise the usage reported with source 'gsiftp' is tried.

    :param name: Name of the RSE.
    :param ddm_type: Type of the endpoint. Failures on 'TEST' endpoints never
                     change RESULT.
    :returns: The 'total' value of the first usage record, or None when the
              usage could not be retrieved from either source.
    """
    global RESULT
    try:
        capacity = get_rse_usage(name, 'root', source='srm')
        tier = list_rse_attributes(name)['tier']
    except RucioException as error:
        if ddm_type == 'TEST':
            print("Ignoring failed 'get capacity' for TEST endpoint: %s: %s" % (name, str(error)))
        else:
            print("WARNING: Could not get capacity of %s: %s" % (name, str(error)))
            # Never downgrade a more severe status recorded earlier in the run.
            RESULT = max(RESULT, WARNING)
        return None

    try:
        return capacity[0]['total']
    except (IndexError, KeyError, TypeError):
        # No usable srm record: fall back to the gsiftp usage below.
        pass

    try:
        # The fallback value is returned to the caller; previously it was
        # computed and then thrown away by an unconditional 'return None'.
        return get_rse_usage(name, 'root', source='gsiftp')[0]['total']
    except Exception:
        # Only endpoints of tier 0, 1 or 2 are reported when space info is missing.
        if tier in ['0', '1', '2'] and ddm_type != 'TEST':
            print("WARNING: No srm info for %s" % name)
            RESULT = max(RESULT, WARNING)
        return None


def set_group_quotas():
    """
    Synchronise the account limits of physics groups with the quotas in AGIS.

    The limits currently stored in Rucio for accounts matching the 'physgroup'
    attribute of an RSE are read from the database, then compared with the
    'quotas' of every active DDM endpoint returned by AGIS. Only the quota of
    the first entry of 'phys_groups' is considered. Missing limits are
    created, differing limits are updated, and the AGIS value 999999 is
    stored as -1 (infinite).

    Problems are reported through the module-level RESULT: WARNING for an
    endpoint that is missing in Rucio, CRITICAL when the database or AGIS
    cannot be read. If AGIS cannot be read, nothing is changed.
    """
    global RESULT
    quotas = {}
    session = get_session()
    query = '''select account, atlas_rucio.id2rse(c.rse_id), bytes from atlas_rucio.account_limits c,
               (select rse_id, rse, value from atlas_rucio.rses a, atlas_rucio.rse_attr_map b where a.id=b.rse_id and b.key='physgroup' and b.value!='None') d
               where c.rse_id=d.rse_id and c.account=d.value'''
    try:
        result = session.execute(query)
        for account, rse, bytes_size in result:
            quotas[account, rse] = bytes_size
    except Exception as error:
        print(error)
        # Record the failure like the other quota functions do. Processing
        # continues best-effort: every AGIS quota is then (re)applied.
        RESULT = CRITICAL

    url = 'http://atlas-agis-api.cern.ch/request/ddmendpoint/query/list/?json&state=ACTIVE&site_state=ACTIVE'
    try:
        resp = requests.get(url=url)
        data = json.loads(resp.content)
    except Exception as error:
        print("Failed to load info from AGIS: %s" % str(error))
        RESULT = CRITICAL
        # Without the AGIS data there is nothing to compare against.
        return

    for ddmendpoint in data:
        # Check if RSE exists
        name = ddmendpoint['name']
        try:
            get_rse(name)
        except RSENotFound:
            print("WARNING: RSE %s missing in Rucio" % name)
            # Never downgrade a CRITICAL recorded earlier.
            RESULT = max(RESULT, WARNING)
            continue

        if not ddmendpoint['phys_groups']:
            continue
        owner = ddmendpoint['phys_groups'][0]

        for vomsgroup in ddmendpoint['quotas']:
            group = vomsgroup.split('/')[2]
            if group != owner:
                continue

            agis_quota = ddmendpoint['quotas'][vomsgroup]
            # 999999 is the AGIS marker for an infinite quota, stored as -1 in Rucio.
            infinite = agis_quota == 999999
            # AGIS values are multiplied by 1000 ** 4 to obtain the limit stored in Rucio.
            expected = -1 if infinite else agis_quota * 1000 ** 4

            if (group, name) not in quotas:
                print('Will add quota on RSE %s for group %s' % (name, group))
            elif infinite:
                if quotas[(group, name)] == -1:
                    continue
                print('Infinite quota defined in AGIS for group %s on %s' % (group, name))
            elif quotas[(group, name)] != expected:
                print('On %s quota for %s differs Rucio : %s vs AGIS : %s' % (name, group, quotas[(group, name)], expected))
            else:
                continue

            try:
                set_account_limit(group, name, expected, 'root')
            except AccountNotFound as error:
                print(error)


def set_user_quotas(ddm_type, fraction):
    """
    Create or update the limits of all user accounts on the RSEs of one type.

    The expected limit of an RSE is its 'default_account_limit_bytes'
    attribute when set, otherwise the total space of the RSE multiplied by
    fraction. RSEs with neither value get no expected limit and are skipped.

    User accounts without any limit on RSEs of this type get a limit on every
    RSE that has an expected value. For the other accounts, existing limits
    that differ from the expected value by more than 1000 are updated.

    :param ddm_type: Value of the RSE attribute 'type' selecting the RSEs.
    :param fraction: Fraction of the total space used as limit when the RSE
                     has no 'default_account_limit_bytes' attribute.
    """
    global RESULT
    quotas = {}
    accounts = {}
    session = get_session()
    user_accounts = [acc['account'] for acc in list_accounts({'account_type': 'USER'})]
    for rse in parse_rse_expression('type=%s' % (ddm_type)):
        total_space = get_srm_total(rse, ddm_type)
        val = list_rse_attributes(rse)
        if 'default_account_limit_bytes' in val:
            # An explicit default always wins over the computed value.
            quotas[rse] = float(val['default_account_limit_bytes'])
        elif total_space:
            quotas[rse] = float(total_space) * fraction

    query = '''select account, atlas_rucio.id2rse(rse_id), bytes from atlas_rucio.account_limits where
rse_id in (select id from atlas_rucio.rses a, atlas_rucio.rse_attr_map b where a.id=b.rse_id and b.key='type' and b.value='%s') ''' % (ddm_type)
    try:
        result = session.execute(query)
        for account, rse, bytes_size in result:
            accounts.setdefault(account, {})[rse] = bytes_size

        for account in user_accounts:
            if account not in accounts:
                print('%s is missing' % (account))
                for rse in quotas:
                    print('Will add quota for account %s on %s' % (account, rse))
                    try:
                        set_account_limit(account, rse, quotas[rse], 'root')
                    except AccountNotFound as error:
                        print(error)
            else:
                for rse, current in accounts[account].items():
                    if rse not in quotas:
                        # No expected value for this RSE (no capacity info and
                        # no default): leave the existing limit untouched
                        # instead of failing for all remaining accounts.
                        continue
                    if abs(current - quotas[rse]) > 1000:
                        print('%s %s : Defined quota %s different from expected quota %s' % (rse, account, current, quotas[rse]))
                        try:
                            set_account_limit(account, rse, quotas[rse], 'root')
                        except AccountNotFound as error:
                            print(error)

    except Exception as error:
        print(error)
        RESULT = CRITICAL


def set_localgroupdisk_quotas(fraction):
    """
    Create the missing limits of accounts on the LOCALGROUPDISKs of their country.

    RSEs whose name ends with LOCALGROUPDISK are grouped by their 'country'
    attribute, and active accounts are grouped by their 'country-*' attribute
    key. Every account that has no limit on a site of its country gets one:
    the 'default_account_limit_bytes' attribute of the RSE when set,
    otherwise the total space of the RSE multiplied by fraction. Sites
    without known total space are skipped, and existing limits are never
    modified.

    Any exception sets the module-level RESULT to CRITICAL.

    :param fraction: Fraction of the total space used as limit when the RSE
                     has no 'default_account_limit_bytes' attribute.
    """
    global RESULT
    session = get_session()
    sites_by_country = {}
    capacity = {}
    defaults = {}
    accounts_by_key = {}

    # Step 1: collect the LOCALGROUPDISK sites per country with their capacity.
    try:
        query = '''select atlas_rucio.id2rse(rse_id), value from atlas_rucio.rses a, atlas_rucio.rse_attr_map b
                   where a.id=b.rse_id and b.key='country' and a.rse like '%LOCALGROUPDISK' and a.deleted!=1'''
        for rse, country in session.execute(query):
            sites_by_country.setdefault(country, []).append(rse)
            if rse in capacity:
                continue
            capacity[rse] = get_srm_total(rse, 'LOCALGROUPDISK')
            attributes = list_rse_attributes(rse)
            if 'default_account_limit_bytes' in attributes:
                defaults[rse] = float(attributes['default_account_limit_bytes'])
    except Exception as error:
        print(error)
        RESULT = CRITICAL

    # Step 2: collect the active accounts per 'country-*' attribute key.
    try:
        query = '''select a.account, a.key, a.value, a.updated_at, a.created_at from atlas_rucio.account_attr_map a, atlas_rucio.accounts b
                   where a.account=b.account and key like 'country-%' and b.status='ACTIVE' '''
        for account, key, _value, _updated_at, _created_at in session.execute(query):
            accounts_by_key.setdefault(key, []).append(account)
    except Exception as error:
        print(error)
        RESULT = CRITICAL

    # Step 3: per country, find the accounts lacking a limit and create it.
    for country in accounts_by_key:
        limits = {}
        try:
            country_short = country.split('-')[1]
            query = '''select account, atlas_rucio.id2rse(rse_id), bytes from atlas_rucio.account_limits where
                       account in (select c.account from atlas_rucio.account_attr_map c, atlas_rucio.accounts d where d.status='ACTIVE'
                       and c.account=d.account and key='%s')
                       and rse_id in (select id from atlas_rucio.rses a, atlas_rucio.rse_attr_map b
                       where a.id=b.rse_id and b.key='country' and b.value='%s' and a.rse like '%%LOCALGROUPDISK') ''' % (country, country_short)
            for account, rse, bytes_size in session.execute(query):
                limits.setdefault(account, {})[rse] = bytes_size

            for account in accounts_by_key[country]:
                held = limits.get(account)
                country_sites = sites_by_country[country_short]
                if not held:
                    missing = country_sites
                    print('%s : %s account is missing quota on sites : %s' % (country, account, missing))
                elif len(held) < len(country_sites):
                    missing = list(set(country_sites) - set(held))
                    print('%s : %s account is missing quota on some sites %s vs %s : %s' % (country, account, len(held), len(country_sites), missing))
                else:
                    missing = []

                for rse in missing:
                    if not capacity.get(rse):
                        # Unknown or empty capacity: no limit can be derived.
                        continue
                    quota = defaults.get(rse, capacity[rse] * fraction)
                    print("Set quota of %dTB on %s for %s" % (quota / 1000 ** 4, rse, account))
                    try:
                        set_account_limit(account, rse, quota, 'root')
                    except AccountNotFound as error:
                        print(error)
        except Exception as error:
            print(error)
            RESULT = CRITICAL

if __name__ == '__main__':
    # For groups, the quota is set and updated using the AGIS information
    set_group_quotas()
    # For the user areas USERDISK and SCRATCHDISK, the quota is created and updated each time the size of the space token changes
    # If an RSE has the attribute default_account_limit_bytes set, that value overrides the default value (50 % of the space)
    set_user_quotas('USERDISK', 0.5)
    set_user_quotas('SCRATCHDISK', 0.5)
    # For the LOCALGROUPDISK area, the quota is created once but never updated
    # The value set is 95% of the space or default_account_limit_bytes if it is set as RSE attribute
    set_localgroupdisk_quotas(0.95)
    # The accumulated status of the whole run becomes the exit code of the script
    sys.exit(RESULT)
