#!/usr/bin/env python
# Copyright (c) 2013-2015, Yahoo Inc.
# Copyrights licensed under the Apache 2.0 License
# See the accompanying LICENSE.txt file for terms.

"""
Command-line utility providing a ping-like interface for pinging TCP/SSL
services.
"""
from __future__ import print_function
import ssl
import socket
import sys
import datetime
import time
import optparse
from serviceping import calc_deviation
from serviceping.commandline import parse_arguments


try:
    from urllib.parse import urlparse
except ImportError:
    from urlparse import urlparse


TIMEOUT = 10  # Default socket timeout in seconds, passed to settimeout() in scan()


def exit_statistics():
    """
    Print ping exit statistics

    Reads the module-level run state set up by the script body (hostname,
    start_time, count_sent, count_received, min_time, max_time, avg_time and
    deviation) and writes the three summary lines to stdout.
    """
    elapsed = datetime.datetime.now() - start_time
    print('\b\b--- %s ping statistics ---' % hostname)
    # timedelta keeps whole days in a separate field; fold them in so a run
    # lasting more than 24 hours still reports its full duration.
    duration = (
        float((elapsed.days * 86400 + elapsed.seconds) * 1000) +
        elapsed.microseconds / 1000.0
    )
    if count_sent:
        loss = 100 - ((float(count_received) / float(count_sent)) * 100)
    else:
        # Nothing was sent, so report total loss instead of dividing by zero.
        loss = 100
    print(
        '%d packets transmitted, %d received, %d%% packet loss, '
        'time %sms' % (count_sent, count_received, loss, duration)
    )
    # avg_time is divided by 1000 here while min/max are converted from their
    # timedelta fields; deviation is printed as supplied.
    print(
        'rtt min/avg/max/dev = %.2f/%.2f/%.2f/%.2f ms' % (
            min_time.seconds * 1000 + float(min_time.microseconds) / 1000,
            float(avg_time) / 1000,
            max_time.seconds * 1000 + float(max_time.microseconds) / 1000,
            float(deviation)
        )
    )


def _wrap_tls(sock, server_hostname):
    """
    Wrap a connected socket in TLS without verifying the peer certificate
    @param sock: connected TCP socket
    @param server_hostname: name sent to the server for SNI
    @return: the TLS-wrapped socket
    """
    if not hasattr(ssl, 'SSLContext'):
        # Interpreters that predate SSLContext only offer the module function.
        return ssl.wrap_socket(sock)
    protocol = getattr(ssl, 'PROTOCOL_TLS_CLIENT', None)
    if protocol is None:
        protocol = ssl.PROTOCOL_SSLv23
    context = ssl.SSLContext(protocol)
    # The probe only measures reachability, so certificates are not checked.
    # check_hostname has to be cleared before verify_mode may be CERT_NONE.
    context.check_hostname = False
    context.verify_mode = ssl.CERT_NONE
    return context.wrap_socket(sock, server_hostname=server_hostname)


def scan(host, port=80, url=None, https=False):
    """
    Run a scan of a host/port
    @param host: hostname or IPv4 address to probe
    @param port: TCP port to connect to (anything int() accepts)
    @param url: request path; when set, an HTTP/1.0 GET is sent after
        connecting and a 1 second socket timeout is used instead of TIMEOUT
    @param https: when true, wrap the connected socket in TLS first
    @return: tuple of (ip, port, state, code, durations, length) where state
        is 'Open', 'Closed' or 'DNS Lookup failed', code is the status token
        from the HTTP response line (0 when there is none), durations maps
        phase names ('dns', 'connect', 'ssl', 'request', 'all') to
        timedeltas for the phases that completed, and length is the number
        of response bytes read
    """
    # Keep the name the caller asked for: it is needed for the Host header
    # and for SNI after ``host`` has been replaced by the resolved address.
    requested_host = host
    # Use a per-call value instead of rebinding the module-level TIMEOUT, so
    # a URL scan does not shorten the timeout of every later call.
    timeout = 1 if url else TIMEOUT
    starts = {}
    ends = {}
    code = 0
    length = 0
    port = int(port)
    starts['all'] = starts['dns'] = datetime.datetime.now()
    try:
        host = socket.gethostbyname(requested_host)
    except socket.gaierror:
        elapsed = datetime.datetime.now() - starts['dns']
        # Same six-field shape as the success path so callers can unpack it.
        return requested_host, port, 'DNS Lookup failed', code, \
            {'dns': elapsed, 'all': elapsed}, length
    ends['dns'] = datetime.datetime.now()
    starts['connect'] = datetime.datetime.now()
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        s.settimeout(timeout)
        result = s.connect_ex((host, port))
        ends['connect'] = datetime.datetime.now()
        if result == 0 and https:
            starts['ssl'] = datetime.datetime.now()
            s = _wrap_tls(s, requested_host)
            ends['ssl'] = datetime.datetime.now()
        if result == 0 and url:
            starts['request'] = datetime.datetime.now()
            request = "GET {0} HTTP/1.0\r\nHost: {1}\r\n\r\n".format(
                url, requested_host)
            s.sendall(request.encode('ascii'))
            raw = s.recv(1500)
            length = len(raw)
            # Only the status line matters; tolerate non-ASCII bytes that
            # may appear elsewhere in the first chunk of the response.
            status_line = raw.decode('ascii', 'replace').split('\n')[0]
            try:
                code = status_line.split()[1]
            except IndexError:
                code = 0
            ends['request'] = datetime.datetime.now()
    finally:
        # Release the socket even when the handshake or request raises.
        s.close()
    ends['all'] = datetime.datetime.now()
    durations = {
        phase: ends[phase] - starts[phase]
        for phase in starts if phase in ends
    }
    state = 'Open' if result == 0 else 'Closed'
    return host, port, state, code, durations, length


if __name__ == '__main__':
    # Script entry point: work out host/port/path from the first positional
    # argument, then probe the service repeatedly, printing one line per
    # successful probe and the summary statistics on exit.
    (options, command_args) = parse_arguments()

    rc = 1
    https = False
    target = command_args[0]
    if target.startswith(('http:', 'https:')):
        # URL form: http[s]://host[:port]/path[?query]
        urlp = urlparse(target)
        hostname = urlp.hostname
        port = urlp.port
        if target.startswith('https:'):
            https = True
        if not port:
            port = 443 if urlp.scheme in ['https'] else 80
        url = urlp.path
        if urlp.query:
            # Keep the query string so the requested resource is the one the
            # user actually asked for.
            url = '%s?%s' % (url or '/', urlp.query)
    else:
        # Plain form: host[:port[:path]]
        args = target.split(':')
        hostname = args[0]
        # A missing or empty port falls back to 80.
        port = args[1] if len(args) > 1 and args[1] else 80
        url = args[2] if len(args) > 2 else None

    # Run state below is read by exit_statistics() as module globals.
    count_sent = count_received = 0
    max_time = min_time = datetime.timedelta(0)
    avg_time = 0
    deviation = 0
    start_time = datetime.datetime.now()
    times = []
    try:
        ip = socket.gethostbyname(hostname)
    except socket.gaierror:
        print('serviceping: unknown host %s' % hostname, file=sys.stderr)
        sys.exit(1)
    print('SERVICEPING %s:%s (%s:%s).' % (hostname, port, ip, port))
    while True:
        # Give up as soon as the name stops resolving.
        try:
            ip = socket.gethostbyname(hostname)
        except socket.gaierror:
            print('serviceping: unknown host %s' % hostname, file=sys.stderr)
            sys.exit(1)

        try:
            ip, port, state, code, durations, length = scan(
                hostname, port, url, https)
            count_sent += 1
            if state.lower() in ['open']:
                count_received += 1
                elapsed = durations['all']
                if elapsed > max_time:
                    max_time = elapsed
                # A zero min_time means "no sample recorded yet".
                if min_time == datetime.timedelta(0) or elapsed < min_time:
                    min_time = elapsed
                # Samples are kept in microseconds because exit_statistics()
                # divides avg_time by 1000 to print milliseconds.
                times.append(elapsed.total_seconds() * 1000000)
                avg_time = sum(times) / float(len(times))
                # NOTE(review): deviation is printed by exit_statistics()
                # without any unit conversion - confirm what calc_deviation
                # returns for microsecond samples.
                deviation = calc_deviation(times, avg_time)
                # Only the most recent 100 samples are kept.
                if len(times) > 100:
                    times = times[-100:]
                if url:
                    code_string = 'response=%s' % code
                else:
                    code_string = ''
                if options.timings:
                    print(
                        '%sfrom %s:%s (%s:%s):%s' % (
                            '%d bytes ' % length if length else '',
                            hostname, port, ip, port, code_string),
                        end=" "
                    )
                    for d in ['dns', 'connect', 'ssl', 'request', 'all']:
                        if d in durations:
                            print('%s=%.2fms' % (
                                d,
                                durations[d].seconds*1000 +
                                float(durations[d].microseconds)/1000
                            ), end=" ")
                else:
                    print(
                        '%sfrom %s:%s (%s:%s):%s time=%.2f ms' % (
                            '%d bytes ' % length if length else '',
                            hostname, port, ip, port, code_string,
                            # Whole-probe time in milliseconds.
                            elapsed.total_seconds() * 1000
                        ),
                        end=" ")
                print()
                rc = 0
            else:
                rc = 1
            if options.count and options.count == count_sent:
                exit_statistics()
                sys.exit(rc)
            time.sleep(options.interval)
        except KeyboardInterrupt:
            exit_statistics()
            sys.exit(rc)
