#!/usr/bin/env python
"""
    Very simple client script used to get the nearest squid servers using the Shoal server's RESTful API.
"""
try:
    import urllib.request as UrlRequest
    from urllib.error import URLError as UrlError
except:
    import urllib2 as UrlRequest
    from urllib2 import URLError as UrlError
import sys
import re
import os
import syslog
from optparse import OptionParser
from subprocess import Popen, PIPE
from multiprocessing import Process, Queue
import requests
import time
import netifaces
import copy
try:
    import queue as Q
except:
    import Queue as Q

from shoal_client import config as config

# Tag every syslog message emitted by this script with "shoal-client".
syslog.openlog("shoal-client")

# Shoal server URL from the config module; get_args() replaces it when -s is given.
server = config.shoal_server_url
# ';'-separated fallback proxies from the config module; placed after any squids found.
default_http_proxy = config.default_squid_proxy
# Announcement text received from a squid broadcasting on the local network, if any.
local_squid = None

# Parsed shoal-server response (dict of squid records); stays None when the request fails.
data = None
# Squids that passed the functional test, keyed by measured time cost -> (ip, port).
server_available_private = {}
server_available_public = {}
# IPs of the server-provided squids that were selected for the proxy string.
server_squids = []
# Command line switches and settings, overridden by get_args().
frontier = False
numsquids = 2  # squids to keep; the server is asked for numsquids + 5
dump = False
# Proxy string fragment built from the tested squids.
best_http_proxy = ''
# Value of HTTP_PROXY (with a trailing ';') when that variable is set and non-empty.
env_proxy = ''
# Alias used by the parsing helpers when building keys and text values
# (presumably a leftover from Python 2 era code -- verify before removing).
unicode = str
# Seconds to wait for a local squid broadcast (socket timeout and queue wait).
timeout = 90
# URLs that functionalTest() fetches through each candidate squid.
paths = config.paths

def get_local_squid(q, t):
    """
        Waits for one UDP packet on port 50000 and, if it starts with the
        broadcast id, puts the text that follows the id on the queue.

        q -- queue-like object with a put() method; nothing is queued when
             the wait times out, fails, or the packet is from another program
        t -- socket timeout in seconds

        Errors are logged to syslog and never raised to the caller once the
        socket exists; the socket is always closed before returning.
    """
    from socket import socket, timeout as socket_timeout, AF_INET, SOCK_DGRAM

    PORT = 50000
    BROADCAST_ID = "fna349fn" #to make sure we don't confuse or get confused by other programs

    s = socket(AF_INET, SOCK_DGRAM) #create UDP socket
    try:
        # bind inside the try so a busy port is logged instead of killing the
        # listener process with a traceback
        s.bind(('', PORT))
        s.settimeout(t)
        data, addr = s.recvfrom(1024) #wait for a packet
        data = data.decode('utf-8')
        if data.startswith(BROADCAST_ID):
            q.put(data[len(BROADCAST_ID):])
    except socket_timeout:
        syslog.syslog(syslog.LOG_ERR, "Socket got a timeout")
    except Exception as e:
        # anything else (port in use, undecodable payload, ...) is not a timeout
        syslog.syslog(syslog.LOG_ERR, "Unable to receive local squid broadcast: %s" % e)
    finally:
        s.close()


def get_args():
    """
        Parses the command line and overrides the module-level settings.

        Sets the globals server, dump, frontier and numsquids. A global keeps
        its current value when the matching option is absent or falsy (so
        "-n 0" does not change numsquids).
    """
    global server
    global dump
    global frontier
    global numsquids

    parser = OptionParser()
    parser.add_option("-s", "--server", action="store", type="string", dest="server",
                      help="Also needs string for specifying the shoal server to use. "
                      "Takes precedence over the option in config file.")
    parser.add_option("-d", "--dump", action="store_true", dest="dump",
                      help="Print closest proxies to terminal for debugging "
                      "instead of executing a cvmfs_talk command to update the active configuration.")
    parser.add_option("-f", "--frontier", action="store_true", dest="frontier",
                      help="Outputs a string suitable for use as a frontier proxy environment variable "
                      "instead of executing a cvmfs_talk command to update the active configuration.")
    parser.add_option("-n", "--squids", action="store", type="int", dest="numsquids",
                      help="Specifies the number of squids to retrieve "
                      "from the shoal-server (default is 2).")
    (options, args) = parser.parse_args()

    if options.server:
        server = options.server
    if options.dump:
        dump = True
    if options.frontier:
        frontier = True
    if options.numsquids:
        numsquids = options.numsquids

def convertServerData(val):
    """
        Converts one raw value captured from the server response into a
        Python value.

        val -- text of the value as matched by the parsing regexes

        Returns an int for all-digit text, a float for other numeric text,
        True/False for the bare JSON literals true/false, None when the text
        contains "null", and otherwise the text with surrounding double
        quotes stripped.
    """
    if val.isdigit():
        return int(val)
    try:
        return float(val)
    except ValueError:
        pass

    # Bare JSON booleans must not stay strings: "false" would be truthy for
    # callers that test the value directly. Quoted text never matches here
    # because the quotes are still part of val.
    literal = val.strip()
    if literal == "true":
        return True
    if literal == "false":
        return False
    if "null" in val:
        return None
    return unicode(val.strip("\""))


def parseServerData(jsonStr):
    """
        Builds a dictionary of squid records from the server response using
        regular expressions.

        jsonStr -- response body, either utf-8 encoded bytes or already
                   decoded text

        Returns a dict keyed by "0", "1", ... (the position of each match in
        the response). Each record holds the values found for the names in
        dataTypes plus a "geo_data" dict with the values found for the names
        in geoDataTypes, all passed through convertServerData.
    """

    # TODO should load this from a config file as it has to match the server
    # Nested properties (i.e geo_data) needs to be handled separately
    dataTypes = ["load", "distance", "squid_port", "last_active", "created",
                 "external_ip", "hostname", "public_ip", "private_ip", "local"]

    geoDataTypes = ["city", "region_name", "area_code", "time_zone", "dma_code",
                    "metro_code", "country_code3", "latitude", "postal_code",
                    "longitude", "country_code", "country_name", "continent"]

    # Decode once up front instead of once per data type.
    if isinstance(jsonStr, bytes):
        text = jsonStr.decode('utf-8')
    else:
        text = jsonStr

    geoKey = unicode("geo_data")
    outDict = {}

    def squidEntry(index):
        # Records are created on demand so a name that matches more often
        # than the first data type cannot raise KeyError.
        return outDict.setdefault(unicode(str(index)), {geoKey: {}})

    # don't really care about data here
    # it is just a simple way to get number of nearest squids:
    # every match of the first data type counts as one squid
    numNearestSquids = len(re.findall("\"" + dataTypes[0] + "\": ([^,}]+)", text))
    for i in range(numNearestSquids):
        squidEntry(i)

    # TODO probably don't need seperate regexes
    # test using geodata one for both
    # the n-th match of a name belongs to the n-th squid record
    for dataType in dataTypes:
        pattern = re.compile("\"" + dataType + "\": ([^,]+)[,|}]")
        for i, val in enumerate(pattern.findall(text)):
            squidEntry(i)[unicode(dataType)] = convertServerData(val)

    # same again for the nested geo_data values
    for geoDataType in geoDataTypes:
        pattern = re.compile("\"" + geoDataType + "\": (\"[^\"]*|[^,]*)")
        for i, val in enumerate(pattern.findall(text)):
            squidEntry(i)[geoKey][unicode(geoDataType)] = convertServerData(val)

    return outDict

def getString(content):
    """
        Returns content as a byte string, encoding text as utf-8.

        Input that bytes() cannot take together with an encoding (byte
        strings on python 3, any string on python 2) goes through
        bytes(content) instead.
    """
    try:
        return bytes(content, 'utf-8')
    except TypeError:
        return bytes(content) # for python 2, or input that is already bytes

def functionalTest(ip, port):
    """
    Checks that the squid at ip:port can serve every url in paths.

    Each url is fetched through the squid and the response must contain a line
    that starts with 'N' and mentions the repository name taken from the url.
    When every url verifies, all urls are fetched a second time and that pass
    is timed.

    Returns the elapsed time of the timed pass in seconds, or None as soon as
    a request fails or a response does not verify. Callers treat a falsy
    result as failure.
    """
    proxystring = "http://%s:%s" % (ip, port)
    #set proxy
    proxies = {
        "http": proxystring,
    }
    # raw string: same pattern as before, without invalid escape sequences
    repo_pattern = re.compile(r"cvmfs\/(.+?)(\/|\.)|opt\/(.+?)(\/|\.)")
    line_prefix = getString('N')

    for targeturl in paths:
        try:
            # group 1 is set for cvmfs/ urls, group 3 for opt/ urls; a url that
            # matches neither raises here and is reported like a request error
            match = repo_pattern.search(targeturl)
            repo = match.group(1)
            if repo is None:
                repo = match.group(3)
            response = requests.get(targeturl, proxies=proxies, timeout=2)
            repo_bytes = getString(repo)
            verified = any(
                line.startswith(line_prefix) and repo_bytes in line
                for line in response.content.splitlines())
        except Exception:
            # note that this would catch any RE errors aswell but they are
            # specified in the config and all fit the pattern.
            syslog.syslog(syslog.LOG_ERR, "%s timeout or proxy error on %s repo" % (ip, targeturl))
            return None
        if not verified:
            syslog.syslog(syslog.LOG_ERR, "%s failed verification on %s" % (ip, targeturl))
            return None

    # second pass over the same urls, timed as a whole
    start_time = time.time()
    for targeturl in paths:
        try:
            requests.get(targeturl, proxies=proxies, timeout=2)
        except Exception:
            syslog.syslog(syslog.LOG_ERR, "%s timeout or proxy error on %s repo" % (ip, targeturl))
            return None
    return time.time() - start_time

def getAvailableIp(squid):
    """
    Runs the functional test against a squid and reports which of its
    addresses is usable.

    The private address is tested first, but only when the squid is flagged
    local and either the client has a private address sharing the squid's
    first two octets or the client's interfaces could not be inspected.
    The public address is tested when no private address was accepted.

    squid -- dict with 'private_ip', 'public_ip' and 'squid_port', and
             optionally 'local'

    Returns a dict with 'available_squid_ip' (None when no test passed),
    'time_cost' (result of the last functional test run, or None) and
    'squid_type' ('private' or 'public').
    """
    squid_private_ip = str(squid['private_ip']) if squid['private_ip'] else None
    squid_public_ip = str(squid['public_ip']) if squid['public_ip'] else None
    squid_port = squid['squid_port']

    # RFC 1918 prefixes. The trailing dot on the 172.x entries matters:
    # a bare '172.16' would also match public addresses such as 172.160.0.1.
    PRIVATE_ADDRESS = tuple('172.%d.' % x for x in range(16, 32)) + ('10.', '192.168.')

    client_private_ip = None
    client_private_ip_not_found = False
    try:
        for interface in netifaces.interfaces():
            try:
                # the last matching address wins
                for link in netifaces.ifaddresses(interface)[netifaces.AF_INET]:
                    if link['addr'].startswith(PRIVATE_ADDRESS):
                        client_private_ip = link['addr']
            except Exception:
                # interface without IPv4 addresses (or an odd entry): skip it
                continue
    except Exception:
        client_private_ip_not_found = True
        syslog.syslog(syslog.LOG_ERR, "Unable to use a private ip address of client")

    local = False
    try:
        local = squid['local']
    except KeyError:
        syslog.syslog(syslog.LOG_ERR, "Squid has no local attribute")

    return_squid_ip = None
    test_time_cost = None
    squid_type = 'public'
    if local and squid_private_ip:
        same_subnet = False
        if client_private_ip:
            # coarse check: compare only the first two octets
            squid_subnet = '.'.join(squid_private_ip.split('.')[:2])
            client_subnet = '.'.join(client_private_ip.split('.')[:2])
            same_subnet = squid_subnet == client_subnet
        if same_subnet or client_private_ip_not_found:
            test_time_cost = functionalTest(squid_private_ip, squid_port)
            if test_time_cost:
                return_squid_ip = squid_private_ip
                squid_type = 'private'
    if not return_squid_ip and squid_public_ip:
        test_time_cost = functionalTest(squid_public_ip, squid_port)
        if test_time_cost:
            return_squid_ip = squid_public_ip
    return {'available_squid_ip': return_squid_ip, 'time_cost': test_time_cost, 'squid_type': squid_type}

def getSortedSquids(squid_dict, current_list, target_num, frontier_format):
    """
        Adds squids from squid_dict, in ascending key order, to a copy of
        current_list until the list holds target_num entries.

        squid_dict      -- maps a sortable key to an (ip, port) sequence
        current_list    -- ips already selected; not modified
        target_num      -- maximum length of the returned list
        frontier_format -- format each squid as "(proxyurl=http://ip:port)"
                           instead of "http://ip:port;"

        Returns (proxy_string, ip_list): the formatted text for the squids
        added by this call and the extended copy of current_list.
    """
    chosen = copy.deepcopy(current_list)
    template = '(proxyurl=http://%s:%s)' if frontier_format else 'http://%s:%s;'
    fragments = []
    for entry in sorted(squid_dict.items()):
        if len(chosen) >= target_num:
            break
        endpoint = entry[1]
        chosen.append(endpoint[0])
        fragments.append(template % (endpoint[0], endpoint[1]))
    return ''.join(fragments), chosen


# Apply the command line overrides before anything else runs.
get_args()

# The rest of the script reads the squid list from a Shoal server, parses it,
# builds the proxy string and, unless --dump or --frontier is given, runs
# cvmfs_talk to set the CVMFS proxies.

## pick up a default proxy from the HTTP_PROXY environment variable, if set
syslog.syslog(syslog.LOG_INFO, "Checking Enviroment Variable for default proxy")
if 'HTTP_PROXY' in os.environ:
    env_proxy = os.environ['HTTP_PROXY']
    syslog.syslog(
        syslog.LOG_INFO,
        "Enviroment Variable located, adding: %s as default proxy" % env_proxy)
    # a non-empty value gets the ';' separator so more proxies can follow it
    if env_proxy:
        env_proxy = env_proxy + ";"
else:
    syslog.syslog(syslog.LOG_ERR, "No HTTP_PROXY enviroment variable found")

## read server data (if it can be read) into a dictionary called data
try:
    #if there is a bad proxy set we will never reach shoal-server
    #this goes direct to avoid any bad configuration
    proxy_handler = UrlRequest.ProxyHandler({})
    opener = UrlRequest.build_opener(proxy_handler)
    # the server is asked for 5 more squids than will be kept
    # (presumably as spares for squids that fail the functional test)
    if server.endswith("/"):
        server_url = server + "%s/" % (numsquids + 5)
    else:
        server_url = server + "/%s/" % (numsquids + 5)
    f = opener.open(server_url, timeout=5)
    try:
        # data = json.loads(f.read())
        data = parseServerData(f.read())
    finally:
        f.close()
    syslog.syslog(syslog.LOG_DEBUG, "Got data from %s" % server)
except Exception as e:
    # Top-level boundary: any failure to fetch or parse the list (bad url,
    # connection error, timeout while reading, malformed answer) leaves data
    # unset so the script carries on with the default proxies.
    # This is where the client now exits if it can't reach shoal, might be worth
    # refactoring instead of injecting code here to reuse the proceeding code.
    #checkEnvVariable()
    #checkConfig()
    syslog.syslog(syslog.LOG_ERR, "Unable to reach shoal-server, reverting to defaults")
    syslog.syslog(syslog.LOG_ERR, "shoal-server request failed: %s" % e)
    #sys.exit(1)


# Listen for a squid announcing itself on the local network; this step is
# skipped when producing frontier output.
if not frontier:
    announcements = Queue()
    listener = Process(target=get_local_squid, args=(announcements, timeout))
    listener.start()
    try:
        # block until the listener reports something or the wait runs out
        local_squid = announcements.get(True, timeout)
    except Q.Empty:
        # nothing was announced in time: local_squid keeps its current value
        pass
    listener.join()

# Only runs when the shoal server returned data.
## Test every squid the server listed, sort the working ones into the private
## and public pools by time cost, then add the fastest ones to best_http_proxy.
# syslog.syslog(syslog.LOG_INFO, "Received data from server, processing.")

if data:
    local_squid_host = local_squid.split(':')[0] if local_squid else None
    for squid in data.values():
        try:
            if local_squid and squid['private_ip'] and squid['private_ip'] == local_squid_host:
                # the server is listing the squid already discovered locally
                continue
            # test the squid received from shoal server
            probe = getAvailableIp(squid)
            usable_ip = probe['available_squid_ip']
            if not usable_ip:
                continue
            if probe['squid_type'] == 'private':
                pool = server_available_private
            else:
                pool = server_available_public
            pool[probe['time_cost']] = (usable_ip, squid['squid_port'])
        except KeyError as e:
            syslog.syslog(
                syslog.LOG_ERR,
                "The data returned from '%s' was missing the key: %s. "
                "Please ensure the server is running the latest version "
                "of Shoal-Server." % (server, e))

    # private squids are taken first, public ones fill the remaining slots
    for pool in (server_available_private, server_available_public):
        new_proxy, server_squids = getSortedSquids(pool, server_squids, numsquids, frontier)
        best_http_proxy += new_proxy
    
# remove default squids that duplicate the local squid or a squid already
# selected from the server
discovered_host = local_squid.split(':')[0] if local_squid else None
kept_defaults = []
for each_default in default_http_proxy.split(';'):
    if not each_default:
        # empty piece from a leading/trailing/double ';' in the configured string
        continue
    # drop an optional "scheme://" prefix and ":port" suffix to get the host
    url_parts = each_default.split('//')
    host_part = url_parts[1] if len(url_parts) > 1 else url_parts[0]
    each_default_ip = host_part.split(':')[0]
    # exact host comparison: a substring test would drop 10.0.0.1 when the
    # local squid is 10.0.0.11
    if each_default_ip == discovered_host or each_default_ip in server_squids:
        continue
    kept_defaults.append(each_default)
# joining afterwards avoids a dangling ';' when the last default is dropped
default_http_proxy = ';'.join(kept_defaults)

#need to reformat default string for frontier nodes
if frontier:
    # each default becomes "(proxyurl=...)"; DIRECT is left out, and so are
    # empty pieces, which would otherwise come out as "(proxyurl=)"
    default_http_proxy = ''.join(
        '(proxyurl=%s)' % proxy
        for proxy in default_http_proxy.split(';')
        if proxy and proxy != 'DIRECT')

# a locally discovered squid goes to the front of the list
if local_squid:
    best_http_proxy = ''.join(["http://", local_squid, ";", best_http_proxy])
# quoted proxy string: selected squids followed by the remaining defaults
cvmfs_http_proxy = ''.join(["\"", best_http_proxy, default_http_proxy, "\""])
syslog.syslog(syslog.LOG_DEBUG, "Data parsing complete.")

if dump:
    # print the proxy string instead of applying it; outside frontier mode the
    # HTTP_PROXY value is placed between the squids and the defaults
    if not frontier:
        cvmfs_http_proxy = "\"" + best_http_proxy + env_proxy + default_http_proxy + "\""
    syslog.syslog(syslog.LOG_INFO, "Dumping proxy string")
    print(cvmfs_http_proxy)

elif frontier:
    #retrieve frontier env variable and insert new squids
    # NOTE(review): os.getenv with a default never raises, so the handler that
    # used to log "FRONTIER_SERVER env variable unset" and exit could not run.
    # The existing behavior is kept: an unset variable is treated as "".
    frontier_var = os.getenv("FRONTIER_SERVER", "")
    new_proxies = best_http_proxy + default_http_proxy
    # insert the new proxies between the first "(...)" group and the first
    # "(proxyurl...)" group that follows it; a function is used as the
    # replacement so the proxy text is inserted literally
    (frontier_output, replacements) = re.subn(
        r'(\(.*?\))(\(proxyurl.*?\))',
        lambda match: match.group(1) + new_proxies + match.group(2),
        frontier_var, 1)
    # if no replacements, then there was no proxyurl to begin with
    if replacements == 0:
        frontier_output = frontier_output + new_proxies
    print(frontier_output)

else:
    # include the case that the client is unable to connect to the shoal_server
    cvmfs_http_proxy = best_http_proxy + env_proxy + default_http_proxy
    syslog.syslog(syslog.LOG_DEBUG, "Data parsing complete.")
    syslog.syslog(syslog.LOG_INFO, "Setting %s as proxy" % cvmfs_http_proxy)

    syslog.syslog(syslog.LOG_DEBUG, "Executing 'cvmfs_talk proxy set %s'" % cvmfs_http_proxy)
    try:
        p = Popen(["cvmfs_talk", "proxy", "set", cvmfs_http_proxy], stdout=PIPE, stderr=PIPE)
        output, errors = p.communicate()
    except OSError as e:
        syslog.syslog(syslog.LOG_ERR, "Could not execute cvmfs_talk: %s" % e.strerror)
        # never exit with status 0 on failure, even if errno is missing
        sys.exit(e.errno or 1)

    # On python 3 the pipes deliver bytes, which syslog.syslog() rejects.
    if not isinstance(errors, str):
        errors = errors.decode('utf-8', 'replace')
    if not isinstance(output, str):
        output = output.decode('utf-8', 'replace')

    if errors:
        syslog.syslog(syslog.LOG_ERR, errors)
    if output:
        syslog.syslog(syslog.LOG_DEBUG, output)
    if p.returncode:
        syslog.syslog(
            syslog.LOG_ERR,
            "WARNING: CVMFS proxies may not have been set correctly for all repos")
        sys.exit(p.returncode)
    else:
        syslog.syslog(syslog.LOG_INFO, "CVMFS proxies set to %s" % cvmfs_http_proxy)

