#!/usr/bin/env python
# Copyright European Organization for Nuclear Research (CERN)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# You may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Authors:
# - Wen Guan, <wen.guan@cern.ch>, 2015-2016

import json
import requests
import sys
import traceback
import urllib2
import urlparse

from rucio.db.sqla.constants import RSEType
from rucio.common.config import config_get
from rucio.core import request as request_core
from rucio.core.distance import add_distance_short, get_distances, update_distances_short
from rucio.core.rse import list_rses, get_rse_attribute

# Status codes. The helper functions below return OK or WARNING as the first
# element of their result tuple, and the main script passes a non-OK status
# straight to sys.exit(). UNKNOWN and CRITICAL are not used in this file.
# NOTE(review): the values look like the usual monitoring-probe convention
# (0=OK, 1=WARNING, 2=CRITICAL, 3=UNKNOWN) - confirm with the caller.
UNKNOWN = 3
CRITICAL = 2
WARNING = 1
OK = 0

# Value of the ('conveyor', 'usercert') configuration option. It is handed to
# requests as both elements of the client ``cert`` tuple when querying FTS.
__USERCERT = config_get('conveyor', 'usercert')


def get_agis_sitenames():
    """
    Download the AGIS DDM endpoint list and index it by endpoint name.

    :returns: ``(OK, mapping)`` on success, where ``mapping[name]`` is a dict
              with the keys ``'sitename'`` (the endpoint's ``site`` value,
              upper-cased), ``'protocols'`` (list of the keys of the
              endpoint's ``protocols`` entry) and ``'main_fts'`` (first
              element of ``servedrestfts['MASTER']``, or None when there is
              no ``'MASTER'`` key).
              ``(WARNING, message)`` with the formatted traceback when the
              download or the parsing fails.
    """
    from contextlib import closing

    url = 'http://atlas-agis-api.cern.ch/request/ddmendpoint/query/list/?json'
    try:
        # urllib2 responses are not context managers themselves; closing()
        # releases the connection even when reading or parsing fails.
        with closing(urllib2.urlopen(url)) as response:
            endpoints = json.loads(response.read())

        result = {}
        for item in endpoints:
            served_fts = item['servedrestfts']
            result[item['name']] = {
                'sitename': item['site'].upper(),
                'protocols': list(item['protocols'].keys()),
                'main_fts': served_fts['MASTER'][0] if 'MASTER' in served_fts else None}
        return OK, result
    except Exception:
        # Exception (not a bare except) so that Ctrl-C and sys.exit() are not
        # silently turned into a WARNING result.
        return WARNING, "Failed to load rse-sitename data from url=%s, error: %s" % (url, traceback.format_exc())


def get_agis_distances():
    """
    Download the AGIS site link list and build a site-to-site closeness map.

    Only entries that carry all of ``src``, ``dst`` and ``closeness`` are
    used. Site names are upper-cased. Every site that appears as a source or
    destination also gets a closeness of 0 to itself.

    :returns: ``(OK, top_distance, distances)`` on success, where
              ``distances[src][dst]`` is the closeness and ``top_distance``
              is the largest closeness seen (0 when none is larger).
              ``(WARNING, None, message)`` with the formatted traceback when
              the download or the parsing fails.
    """
    from contextlib import closing

    url = 'http://atlas-agis-api.cern.ch/request/site/query/list_links/?json'
    try:
        # urllib2 responses are not context managers themselves; closing()
        # releases the connection even when reading or parsing fails.
        with closing(urllib2.urlopen(url)) as response:
            site_list = json.loads(response.read())

        top_distance = 0
        result = {}
        for item in site_list:
            if 'src' in item and 'dst' in item and 'closeness' in item:
                dst = item['dst'].upper()
                src = item['src'].upper()
                closeness = item['closeness']

                src_links = result.setdefault(src, {})
                src_links[dst] = closeness

                # fix transfer inside the same site
                src_links[src] = 0
                result.setdefault(dst, {})[dst] = 0

                if closeness > top_distance:
                    top_distance = closeness
        return OK, top_distance, result
    except Exception:
        # Exception (not a bare except) so that Ctrl-C and sys.exit() are not
        # silently turned into a WARNING result.
        return WARNING, None, "Failed to load distance data from url=%s, error: %s" % (url, traceback.format_exc())


def get_active_ftses():
    """
    Collect the distinct FTS hosts named in the 'fts' RSE attribute.

    Each attribute value is treated as a comma-separated list of hosts.

    :returns: list of hosts without duplicates, in first-seen order.
    """
    hosts = []
    for attribute_value in get_rse_attribute(key='fts'):
        for candidate in attribute_value.split(","):
            if candidate in hosts:
                continue
            hosts.append(candidate)
    return hosts


def get_fts_transfer_summary(fts_host):
    """
    Query the monitoring overview of one FTS host for the 'atlas' VO.

    The monitoring URL is derived from ``fts_host`` by replacing '8446' with
    '8449' and appending the ftsmon overview path.

    :param fts_host: FTS host URL.
    :returns: dict keyed by ``"<source_se>#<dest_se>"`` whose values hold the
              'submitted', 'active', 'finished', 'failed' and
              'transfer_speed' counters (0 when the item lacks the field;
              'transfer_speed' is read from the item's 'current' field).
              None when the request fails or does not answer with status 200.
    """
    # (name in the summary, field name in the monitoring item)
    counters = (("submitted", "submitted"),
                ("active", "active"),
                ("finished", "finished"),
                ("failed", "failed"),
                ("transfer_speed", "current"))
    try:
        monitor_url = fts_host.replace('8446', '8449') + '/fts3/ftsmon/overview?vo=atlas&time_window=1&page=all'
        response = requests.get(monitor_url,
                                verify=False,
                                cert=(__USERCERT, __USERCERT),
                                headers={'Content-Type': 'application/json'})
        if not (response and response.status_code == 200):
            print("Failed to get fts %s transfer summary, error: %s" % (fts_host, response.text if response is not None else response))
            return None

        summary = {}
        for item in response.json()['overview']['items']:
            if item['vo_name'] != 'atlas':
                continue
            link = item['source_se'] + "#" + item['dest_se']
            if link in summary:
                # the first entry seen for a link wins
                print("Duplicated key %s: %s" % (link, summary[link]))
                continue
            summary[link] = dict((name, item[field] if field in item else 0) for name, field in counters)
        return summary
    except:
        print("Failed to get fts %s transfer summary, error: %s" % (fts_host, traceback.format_exc()))
    return None


def get_ftses_transfer_summary():
    """
    Gather the transfer summary of every active FTS host.

    :returns: ``(OK, summaries)`` where ``summaries[fts_host]`` is the dict
              returned by get_fts_transfer_summary(), or an empty list when
              that call returned nothing usable.
              ``(WARNING, message)`` with the formatted traceback on failure.
    """
    try:
        result = {}
        for fts_host in get_active_ftses():
            fts_summary = get_fts_transfer_summary(fts_host)
            result[fts_host] = fts_summary if fts_summary else []
        return OK, result
    except Exception:
        # The message has a single placeholder. It deliberately does not
        # mention a host: the failure may happen before any host is known.
        return WARNING, "Failed to get ftses summary, error: %s" % traceback.format_exc()


def get_fts_info(fts_summary, src_protocols, dest_protocols):
    """
    Find the FTS summary entry for a source/destination protocol pair.

    Every protocol URL is reduced to ``scheme://host`` (port dropped) and the
    pairs are tried in order, sources in the outer loop.

    :param fts_summary: container keyed by ``"<src>#<dest>"`` endpoint pairs.
    :param src_protocols: iterable of source protocol URLs.
    :param dest_protocols: iterable of destination protocol URLs.
    :returns: the first matching summary entry, or None when nothing matches
              or an error occurs (the error is printed).
    """
    def endpoint_of(protocol):
        # scheme://host, with anything after the first ':' of the netloc cut
        pieces = urlparse.urlparse(protocol)
        host = pieces.netloc.partition(':')[0]
        return pieces.scheme + "://" + host

    try:
        for src_protocol in src_protocols:
            src_endpoint = endpoint_of(src_protocol)
            for dest_protocol in dest_protocols:
                link = src_endpoint + "#" + endpoint_of(dest_protocol)
                if link in fts_summary:
                    return fts_summary[link]
    except:
        print("Failed to get fts info: %s" % traceback.format_exc())
    return None


def get_downtime_list():
    """
    List the ids of the RSEs whose 'availability_read' filter value is False.

    :returns: ``(OK, list of RSE ids)`` on success,
              ``(WARNING, message)`` with the formatted traceback on failure.
    """
    try:
        return OK, [rse['id'] for rse in list_rses(filters={'availability_read': False})]
    except:
        return WARNING, "Failed to get downtime list: %s" % traceback.format_exc()


def get_rse_distances():
    """
    Load the stored distances and index them by source and destination RSE id.

    Each row additionally gets a 'distance' entry copied from its
    'agis_distance' entry.

    :returns: ``(OK, distances)`` where ``distances[src_rse_id][dest_rse_id]``
              is the row, or ``(WARNING, message)`` with the formatted
              traceback on failure.
    """
    try:
        by_source = {}
        for entry in get_distances():
            entry['distance'] = entry['agis_distance']
            by_source.setdefault(entry['src_rse_id'], {})[entry['dest_rse_id']] = entry
        return OK, by_source
    except:
        return WARNING, "Failed to get rse distances: %s" % traceback.format_exc()


def get_rses(sitenames):
    """
    List the RSEs that are usable for the distance computation.

    RSEs flagged 'deleted' or 'staging_area' are skipped silently; RSEs whose
    name is not in ``sitenames`` are skipped with a printed notice.

    :param sitenames: container of known RSE names.
    :returns: ``(OK, list of RSE rows)`` on success,
              ``(WARNING, message)`` with the formatted traceback on failure.
    """
    try:
        usable = []
        for rse in list_rses():
            if rse['deleted'] or rse['staging_area']:
                continue
            if rse['rse'] in sitenames:
                usable.append(rse)
            else:
                print("Cannot find site name for rse %s" % rse['rse'])
        return OK, usable
    except:
        return WARNING, "Failed to get all active rses: %s" % traceback.format_exc()


def get_heavy_load_rses(threshold=5000):
    """
    Map the RSE ids reported by request_core.get_heavy_load_rses to their load.

    :param threshold: forwarded unchanged to request_core.get_heavy_load_rses.
    :returns: ``(OK, {rse_id: load})`` on success,
              ``(WARNING, message)`` with the formatted traceback on failure.
    """
    try:
        heavy = request_core.get_heavy_load_rses(threshold=threshold)
        return OK, dict((entry['rse_id'], entry['load']) for entry in heavy)
    except:
        return WARNING, "Failed to get heavy load rses: %s" % traceback.format_exc()


def distance_changed(old_distance, new_distance):
    """
    Tell whether two distance records differ in a tracked field.

    Only 'ranking' and 'agis_distance' are compared; a missing key counts as
    None.

    :param old_distance: mapping with the stored values.
    :param new_distance: mapping with the freshly computed values.
    :returns: True when at least one tracked field differs, False otherwise.
    """
    tracked_fields = ('ranking', 'agis_distance')
    return any(old_distance.get(field) != new_distance.get(field) for field in tracked_fields)


def get_ranking(ranking, fts_info, threshold, speed_rank, heavy_load_rse=False):
    """
    Adjust a base ranking using the FTS counters of a link.

    When fewer than ``threshold`` transfers are submitted and the RSE is not
    heavily loaded, the ranking is raised by ``transfer_speed / speed_rank``.
    Otherwise it is lowered by ``submitted / finished``, or by ``submitted``
    when nothing has finished. Counters that are None or 0 leave the ranking
    untouched.

    :param ranking: base ranking (a number).
    :param fts_info: mapping with the keys 'submitted', 'finished' and
                     'transfer_speed'; the values may be None.
    :param threshold: number of submitted transfers from which the link is
                      treated as congested.
    :param speed_rank: amount of transfer speed that is worth one rank.
    :param heavy_load_rse: True to treat the link as congested regardless of
                           the submitted count.
    :returns: the adjusted ranking, truncated to int.
    """
    submitted = fts_info['submitted']
    # An unknown (None) submitted count is treated as "below the threshold".
    # This is spelled out instead of relying on Python 2 ordering None before
    # every number, which raises TypeError on Python 3.
    below_threshold = submitted is None or submitted < threshold

    if below_threshold and not heavy_load_rse:
        # not too many queued transfers, high rank for fast link
        speed = fts_info['transfer_speed']
        if speed:
            ranking += speed / speed_rank
    elif submitted:
        finished = fts_info['finished']
        if finished:
            ranking -= submitted / finished
        else:
            ranking -= submitted
    return int(ranking)


if __name__ == '__main__':
    # Recompute the ranking and AGIS distance of every ordered RSE pair and
    # store the pairs that are new or whose values changed.

    threshold = 1000
    speed_rank = 10  # MB/s, every 10 MB/s is one rank

    # Every loader returns its status first; a non-OK status ends the run
    # with that status as exit code after printing the error message.
    retVal, result = get_agis_sitenames()
    if retVal != OK:
        print(result)
        sys.exit(retVal)
    sitenames = result

    retVal, top_distance, result = get_agis_distances()
    if retVal != OK:
        print(result)
        sys.exit(retVal)
    agis_distances = result

    retVal, result = get_downtime_list()
    if retVal != OK:
        print(result)
        sys.exit(retVal)
    # get_downtime_list() returns RSE ids; keep them in a set because they
    # are looked up once per source RSE.
    downtime_rse_ids = set(result)

    retVal, result = get_rse_distances()
    if retVal != OK:
        print(result)
        sys.exit(retVal)
    old_distances = result

    retVal, result = get_rses(sitenames)
    if retVal != OK:
        print(result)
        sys.exit(retVal)
    rses = result

    retVal, result = get_heavy_load_rses(threshold)
    if retVal != OK:
        print(result)
        sys.exit(retVal)
    heavy_load_rses = result

    retVal, result = get_ftses_transfer_summary()
    if retVal != OK:
        print(result)
        sys.exit(retVal)
    fts_summary = result

    top_rank = top_distance + 2

    for src_rse in rses:
        # everything that only depends on the source is computed once here
        src_site = sitenames[src_rse['rse']]
        src_sitename = src_site['sitename']
        src_rse_id = src_rse['id']
        # The downtime list holds RSE ids, so the source RSE id is tested;
        # a site name can never match an entry of that list.
        src_in_downtime = src_rse_id in downtime_rse_ids
        is_heavy_load_rse = src_rse_id in heavy_load_rses

        for dest_rse in rses:
            dest_site = sitenames[dest_rse['rse']]
            dest_sitename = dest_site['sitename']
            dest_rse_id = dest_rse['id']
            main_fts = dest_site['main_fts']

            fts_info = None
            if main_fts and main_fts in fts_summary:
                fts_info = get_fts_info(fts_summary[main_fts], src_site['protocols'], dest_site['protocols'])
            if fts_info is None:
                fts_info = {'active': None, 'failed': None, 'finished': None, 'transfer_speed': None, 'submitted': None}

            if src_sitename in agis_distances and dest_sitename in agis_distances[src_sitename]:
                if agis_distances[src_sitename][dest_sitename] > -1:
                    distance = agis_distances[src_sitename][dest_sitename]
                else:
                    distance = None
            else:
                # for site which is not in agis distance
                distance = top_distance / 2

            if src_in_downtime:
                ranking = 0
            else:
                if distance is None:
                    ranking = None
                else:
                    ranking = top_rank - distance

                if src_rse['rse_type'] == RSEType.TAPE:
                    # lower down TAPE rank (this also replaces a None ranking)
                    ranking = 1

            new_distance = {'ranking': None if ranking is None else get_ranking(ranking, fts_info, threshold, speed_rank, heavy_load_rse=is_heavy_load_rse),
                            'agis_distance': distance}

            if src_rse_id in old_distances and dest_rse_id in old_distances[src_rse_id]:
                # known pair: only write when a tracked value changed
                if not distance_changed(old_distances[src_rse_id][dest_rse_id], new_distance):
                    continue
                update_distances_short(src_rse_id=src_rse_id, dest_rse_id=dest_rse_id, distance=new_distance)
            else:
                add_distance_short(src_rse_id=src_rse_id, dest_rse_id=dest_rse_id, distance=new_distance)
    sys.exit(OK)
