#!/usr/bin/env python
import argparse
import codecs
import os
import re
import readline
import select
import signal
import socket
import sys
import threading
import time

# On Python 2 the builtin input() evaluates whatever the user types, so use
# raw_input() instead. Python 3's input() already returns the raw line.
if sys.version_info[0] < 3:
    input = raw_input  # noqa: F821 -- only defined on Python 2


def parse_pid(value, regex=re.compile(r'(/tmp/manhole-)?(?P<pid>\d+)')):
    """Convert a command-line PID argument into an int.

    Used as an argparse ``type=`` callable. Accepts either a bare process id
    (``1234``) or the manhole socket path for it (``/tmp/manhole-1234``).

    :param value: Raw command-line string.
    :param regex: Compiled pattern exposing a ``pid`` group; it must match
        the entire ``value``.
    :returns: The process id as an int.
    :raises argparse.ArgumentTypeError: If ``value`` is in neither form.
    """
    match = regex.match(value)
    # ``match`` only anchors at the start, so also require that the whole
    # string was consumed. Otherwise "1234abc" or "/tmp/manhole-12.sock"
    # would be truncated to a (possibly wrong) PID, which may then be sent
    # a signal.
    if not match or match.end() != len(value):
        raise argparse.ArgumentTypeError("PID must be in one of these forms: 1234 or /tmp/manhole-1234")

    return int(match.group('pid'))

# Command-line interface: a target PID, a timeout, and an optional signal
# to send to the target before connecting.
parser = argparse.ArgumentParser(description='Connect to a manhole.')
parser.add_argument('pid', metavar='PID', type=parse_pid,
                    help='A numerical process id, or a path in the form: /tmp/manhole-1234')
parser.add_argument('-t', '--timeout', dest='timeout', default=1, type=float,
                    help='Timeout to use. Default: %(default)s seconds.')
# At most one of the signal options may be given; both store into `signal`.
group = parser.add_mutually_exclusive_group()
# %(const)d rather than %(const)s: on Python 3 the signal constants are
# IntEnum members whose str() can render as "Signals.SIGUSR1" instead of
# the signal number.
group.add_argument('-1', '-USR1', dest='signal', action='store_const', const=signal.SIGUSR1,
                   help='Send USR1 (%(const)d) to the process before connecting.')
group.add_argument('-2', '-USR2', dest='signal', action='store_const', const=signal.SIGUSR2,
                   help='Send USR2 (%(const)d) to the process before connecting.')


class ConnectionHandler(threading.Thread):
    """Thread that relays data between the manhole socket and a local pipe.

    Bytes received on ``sock`` are decoded as UTF-8 and written to stdout,
    after which readline redraws the current input line. Bytes that become
    readable on ``read_fd`` are forwarded unchanged to ``sock``. The loop
    ends when ``should_run`` is cleared or the remote end closes the socket.
    """

    def __init__(self, sock, read_fd, timeout=None):
        """
        :param sock: Connected socket to the manhole.
        :param read_fd: Readable end of a pipe carrying user input.
        :param timeout: Poll interval in seconds, i.e. how often
            ``should_run`` is rechecked. ``None`` means use the module-level
            ``args.timeout`` (the ``--timeout`` option).
        """
        super(ConnectionHandler, self).__init__()
        self.sock = sock
        self.read_fd = read_fd
        self.timeout = timeout
        self.should_run = True

    def run(self):
        sock = self.sock
        conn_fd = sock.fileno()
        read_fd = self.read_fd
        timeout = args.timeout if self.timeout is None else self.timeout
        # select.poll().poll() takes its timeout in milliseconds, not seconds.
        timeout_ms = timeout * 1000
        # A multi-byte UTF-8 character can be split across two recv() calls.
        # The incremental decoder buffers the partial bytes instead of raising.
        decoder = codecs.getincrementaldecoder('utf8')(errors='replace')

        # A private poll object, so the thread does not depend on a global
        # that only exists when this file runs as a script.
        poller = select.poll()
        events = select.POLLIN | select.POLLPRI | select.POLLERR | select.POLLHUP
        poller.register(read_fd, events)
        poller.register(conn_fd, events)

        while self.should_run:
            for fd, _ in poller.poll(timeout_ms):
                if fd == conn_fd:
                    data = sock.recv(1024)
                    if not data:
                        # Peer closed the connection. Stop here; otherwise
                        # POLLHUP keeps firing and the loop spins forever.
                        self.should_run = False
                        break
                    sys.stdout.write(decoder.decode(data))
                    sys.stdout.flush()
                    readline.redisplay()
                elif fd == read_fd:
                    data = os.read(read_fd, 1024)
                    if not data:
                        # Write end of the pipe was closed: no more input will
                        # arrive. Keep relaying socket output without spinning.
                        poller.unregister(read_fd)
                        continue
                    sock.sendall(data)
                else:
                    raise RuntimeError("Unknown FD %s" % fd)

if __name__ == "__main__":
    args = parser.parse_args()

    # Persist interactive command history across sessions.
    histfile = os.path.join(os.path.expanduser("~"), ".manhole_history")
    try:
        readline.read_history_file(histfile)
    except IOError:
        pass
    import atexit
    atexit.register(readline.write_history_file, histfile)
    del histfile

    if args.signal:
        os.kill(args.pid, args.signal)

    uds_path = '/tmp/manhole-%s' % args.pid
    deadline = time.time() + args.timeout
    # Retry until the deadline, because the socket may not exist yet (for
    # example right after signalling). Always make at least one attempt, and
    # use a fresh socket each time instead of reusing one whose connect failed.
    while True:
        sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        sock.settimeout(args.timeout)
        try:
            sock.connect(uds_path)
        except socket.error as exc:
            sock.close()
            print("Failed to connect to %r: %r" % (uds_path, exc))
            if time.time() >= deadline:
                # Don't continue with an unconnected socket.
                sys.exit("Giving up on %r after %s seconds." % (uds_path, args.timeout))
            # Back off briefly rather than hammering connect() in a tight loop.
            time.sleep(0.1)
        else:
            break

    # Module-level poll object, kept for any code that uses the global `poll`.
    poll = select.poll()
    read_fd, write_fd = os.pipe()

    thread = ConnectionHandler(sock, read_fd)
    thread.start()

    try:
        # Each line typed by the user goes into the pipe; the handler thread
        # forwards it to the socket.
        while thread.is_alive():
            try:
                data = input().encode('utf8')
            except EOFError:
                break
            os.write(write_fd, data + b'\n')
    except KeyboardInterrupt:
        pass
    finally:
        thread.should_run = False
        thread.join()
        sock.close()
        os.close(read_fd)
        os.close(write_fd)
