#!/usr/bin/env python
# Copyright European Organization for Nuclear Research (CERN)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# You may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Authors:
# - Wen Guan, <wen.guan@cern.ch>, 2015

import json
import sys
import traceback
import urllib2

from rucio.db.sqla.constants import RSEType
from rucio.core import request as request_core
from rucio.core.distance import add_distance, get_distances, update_distances
from rucio.core.rse import list_rses


# Status codes returned as the first element of every helper's result tuple
# and passed to sys.exit() by the script body. Only OK and WARNING are used in
# this file as shown. The numbering (0=OK, 1=WARNING, 2=CRITICAL, 3=UNKNOWN)
# matches the Nagios plugin convention -- presumably this script runs as a
# monitoring probe; confirm against the deployment.
UNKNOWN = 3
CRITICAL = 2
WARNING = 1
OK = 0


def get_agis_sitenames():
    """
    Download the AGIS DDM endpoint list and map each endpoint name to its site.

    :returns: (OK, dict) on success, where the dict maps each item's 'name'
              to its upper-cased 'site'; (WARNING, message) on any failure,
              where the message contains the URL and the formatted traceback.
    """
    url = 'http://atlas-agis-api.cern.ch/request/ddmendpoint/query/list/?json'
    try:
        u = urllib2.urlopen(url)
        try:
            content = u.read()
        finally:
            # Release the HTTP connection even when the read fails.
            u.close()

        result = {}
        for item in json.loads(content):
            rse = item['name']
            result[rse] = item['site'].upper()
        return OK, result
    except Exception:
        # 'Exception' rather than a bare except, so that KeyboardInterrupt and
        # SystemExit are not silently converted into a WARNING result.
        return WARNING, "Failed to load rse-sitename data from url=%s, error: %s" % (url, traceback.format_exc())


def get_agis_distances():
    """
    Download the AGIS site link list and build a site-to-site closeness table.

    Only items carrying all of 'src', 'dst' and 'closeness' are used. Site
    names are upper-cased. Every site seen as a source or destination also
    gets a self-distance of 0.

    :returns: (OK, top_distance, table) on success, where table[src][dst] is
              the closeness and top_distance is the largest closeness seen
              (0 if none is larger); (WARNING, None, message) on any failure,
              where the message contains the URL and the formatted traceback.
    """
    url = 'http://atlas-agis-api.cern.ch/request/site/query/list_links/?json'
    try:
        u = urllib2.urlopen(url)
        try:
            content = u.read()
        finally:
            # Release the HTTP connection even when the read fails.
            u.close()

        top_distance = 0
        result = {}
        for item in json.loads(content):
            if not ('src' in item and 'dst' in item and 'closeness' in item):
                continue
            dst = item['dst'].upper()
            src = item['src'].upper()
            closeness = item['closeness']

            links = result.setdefault(src, {})
            links[dst] = closeness

            # fix transfer inside the same site: assigned after the link above
            # so that a src == dst entry ends up with distance 0.
            links[src] = 0
            result.setdefault(dst, {})[dst] = 0

            if closeness > top_distance:
                top_distance = closeness
        return OK, top_distance, result
    except Exception:
        # 'Exception' rather than a bare except, so that KeyboardInterrupt and
        # SystemExit are not silently converted into a WARNING result.
        return WARNING, None, "Failed to load distance data from url=%s, error: %s" % (url, traceback.format_exc())


def get_downtime_list():
    """
    Collect the ids of all RSEs listed with availability_read set to False.

    :returns: (OK, list of RSE ids) on success; (WARNING, message with the
              formatted traceback) on any failure.
    """
    try:
        not_readable = list_rses(filters={'availability_read': False})
        return OK, [entry['id'] for entry in not_readable]
    except:
        return WARNING, "Failed to get downtime list: %s" % traceback.format_exc()


def get_rse_distances():
    """
    Load the stored distances and index them by source and destination RSE id.

    :returns: (OK, table) on success, where table[src_rse_id][dest_rse_id] is
              {'distance': <row agis_distance>, 'ranking': <row ranking>};
              (WARNING, message with the formatted traceback) on any failure.
    """
    try:
        by_source = {}
        for entry in get_distances():
            source_id, target_id = entry['src_rse_id'], entry['dest_rse_id']
            per_target = by_source.setdefault(source_id, {})
            per_target[target_id] = {'distance': entry['agis_distance'], 'ranking': entry['ranking']}
        return OK, by_source
    except:
        return WARNING, "Failed to get rse distances: %s" % traceback.format_exc()


def get_rses(sitenames):
    """
    List the RSEs that are neither deleted nor staging areas and have a known site.

    An RSE whose name is missing from sitenames is reported on stdout and left out.

    :param sitenames: container of RSE names with a known site (membership is
                      tested with 'in').
    :returns: (OK, list of RSE entries) on success; (WARNING, message with the
              formatted traceback) on any failure.
    """
    try:
        usable = []
        for entry in list_rses():
            if not (entry['deleted'] or entry['staging_area']):
                if entry['rse'] in sitenames:
                    usable.append(entry)
                else:
                    print("Cannot find site name for rse %s" % entry['rse'])
        return OK, usable
    except:
        return WARNING, "Failed to get all active rses: %s" % traceback.format_exc()


def get_heavy_load_rses(threshold=5000):
    """
    Ask the request core for heavily loaded RSEs and map each RSE id to its load.

    :param threshold: forwarded unchanged to request_core.get_heavy_load_rses.
    :returns: (OK, dict mapping rse_id to load) on success; (WARNING, message
              with the formatted traceback) on any failure.
    """
    try:
        heavy = request_core.get_heavy_load_rses(threshold=threshold)
        return OK, dict((entry['rse_id'], entry['load']) for entry in heavy)
    except:
        return WARNING, "Failed to get heavy load rses: %s" % traceback.format_exc()


def _exit_unless_ok(status, message):
    """
    Print message and terminate with status as exit code unless status is OK.

    :param status: status code returned by one of the get_* helpers.
    :param message: the helper's payload; on failure it is the error text.
    """
    if status != OK:
        print(message)
        sys.exit(status)


def lookup_distance(agis_distances, src_sitename, dest_sitename, top_distance):
    """
    Return the distance to store for a source/destination site pair.

    :param agis_distances: table of the form table[src_site][dst_site] -> closeness.
    :param src_sitename: source site name.
    :param dest_sitename: destination site name.
    :param top_distance: largest closeness in the table.
    :returns: the closeness when the pair is in the table and greater than -1,
              None when the pair is in the table with a closeness of -1 or
              less, and top_distance / 2 when the pair is not in the table.
    """
    if src_sitename in agis_distances and dest_sitename in agis_distances[src_sitename]:
        closeness = agis_distances[src_sitename][dest_sitename]
        return closeness if closeness > -1 else None
    # for site which is not in agis distance
    return top_distance / 2


def compute_ranking(distance, top_rank, source_unreadable, is_tape, load, threshold):
    """
    Return the ranking to store for one source/destination pair.

    :param distance: value from lookup_distance (may be None).
    :param top_rank: rank from which the distance is subtracted.
    :param source_unreadable: True when the source RSE is unavailable for reading.
    :param is_tape: True when the source RSE is a TAPE RSE.
    :param load: load of the source RSE when it is heavily loaded, else None.
    :param threshold: load threshold; the ranking drops by load / threshold.
    :returns: 0 for an unreadable source; otherwise top_rank - distance
              (None when distance is None), forced to 1 for TAPE, then
              reduced by load / threshold when a load is given.
    """
    if source_unreadable:
        return 0

    ranking = None if distance is None else top_rank - distance
    if is_tape:
        # lower down TAPE rank
        ranking = 1
    # A pair with no usable distance has no ranking to reduce; subtracting
    # from None would raise a TypeError.
    if load is not None and ranking is not None:
        ranking -= load / threshold
    return ranking


if __name__ == '__main__':

    threshold = 10000

    # Gather every input first; any failure prints the error and exits with
    # the helper's status code before anything is written.
    retVal, sitenames = get_agis_sitenames()
    _exit_unless_ok(retVal, sitenames)

    retVal, top_distance, agis_distances = get_agis_distances()
    _exit_unless_ok(retVal, agis_distances)

    retVal, downtime_list = get_downtime_list()
    _exit_unless_ok(retVal, downtime_list)

    retVal, old_distances = get_rse_distances()
    _exit_unless_ok(retVal, old_distances)

    retVal, rses = get_rses(sitenames)
    _exit_unless_ok(retVal, rses)

    retVal, heavy_load_rses = get_heavy_load_rses(threshold)
    _exit_unless_ok(retVal, heavy_load_rses)

    top_rank = top_distance + 2

    # get_downtime_list() returns RSE ids, so the downtime test must be made
    # against the source RSE id (not its site name).
    unreadable_rse_ids = set(downtime_list)

    for src_rse in rses:
        src_sitename = sitenames[src_rse['rse']]
        src_rse_id = src_rse['id']

        # Everything below depends on the source only; compute it once per source.
        source_unreadable = src_rse_id in unreadable_rse_ids
        is_tape = src_rse['rse_type'] == RSEType.TAPE
        load = heavy_load_rses.get(src_rse_id)
        known_dests = old_distances.get(src_rse_id, {})

        for dest_rse in rses:
            dest_sitename = sitenames[dest_rse['rse']]
            dest_rse_id = dest_rse['id']

            distance = lookup_distance(agis_distances, src_sitename, dest_sitename, top_distance)
            ranking = compute_ranking(distance, top_rank, source_unreadable, is_tape, load, threshold)

            if dest_rse_id not in known_dests:
                add_distance(src_rse_id=src_rse_id, dest_rse_id=dest_rse_id, ranking=ranking, agis_distance=distance)
                continue

            previous = known_dests[dest_rse_id]
            # Only write when the stored values actually changed.
            if previous['distance'] != distance or previous['ranking'] != ranking:
                update_distances(src_rse_id=src_rse_id, dest_rse_id=dest_rse_id, ranking=ranking, agis_distance=distance)
    sys.exit(OK)
