#!/usr/bin/env python
import sys
import select
import socket
import logging
import argparse
import subprocess

from six.moves import xrange
from tfmesos import cluster

# Command-line interface: task counts, Mesos connection settings, per-task
# resources for each job, and the command to run (plus its arguments).
parser = argparse.ArgumentParser()

# Number of tasks for each job; both must be given.
parser.add_argument('-w', '--nworker', required=True, type=int)
parser.add_argument('-s', '--nserver', required=True, type=int)

parser.add_argument('-m', '--master', type=str)
parser.add_argument(
    '-n', '--name', type=str,
    help='Mesos framework name to register as.',
)
# The value is upper-cased before being checked against `choices`, so the
# option is case-insensitive on the command line.
parser.add_argument(
    '-C', '--containerizer_type',
    nargs='?',
    type=lambda s: s.upper(),
    choices=['MESOS', 'DOCKER'],
)
parser.add_argument('-f', '--force_pull_image', action="store_true")

# Resource options, registered worker-first then server, each in
# cpus/gpus/mem order: -Cw/-Gw/-Mw and -Cs/-Gs/-Ms.
for _suffix, _job in (('w', 'worker'), ('s', 'server')):
    for _letter, _resource, _type, _default in (
        ('C', 'cpus', float, 1.0),
        ('G', 'gpus', int, 0),
        ('M', 'mem', float, 1024.0),
    ):
        parser.add_argument(
            '-%s%s' % (_letter, _suffix),
            '--%s-%s' % (_job, _resource),
            type=_type,
            default=_default,
        )

parser.add_argument('-v', '--verbose', action='store_true')
# May be repeated; every occurrence is collected into a list.
parser.add_argument('-V', '--volume', action='append', type=str)
parser.add_argument('-r', '--role', action='store', type=str)

# The command to launch, followed by any number of arguments for it.
parser.add_argument('cmd', type=str)
parser.add_argument('args', nargs='*')

args = parser.parse_args()

# Collapse the command and its arguments into a single space-separated
# string.
# NOTE(review): any quoting applied by the invoking shell is not preserved
# by this join, so an argument that itself contains spaces becomes several
# words -- confirm how the cluster executes `cmd` before relying on that.
cmd = ' '.join([args.cmd] + args.args)

# Turn every repeated -V/--volume value into one entry of a mapping: the text
# before the first ':' is the key and everything after it is the value.
# A value without ':' is reported as a normal usage error instead of
# surfacing as an opaque ValueError traceback from dict().
volumes = {}
for _spec in args.volume or ():
    _key, _sep, _value = _spec.partition(':')
    if not _sep:
        parser.error(
            "invalid -V/--volume value %r: expected two ':'-separated parts"
            % _spec
        )
    volumes[_key] = _value

# Job names and task counts used to build the job definitions.
server_name, worker_name = 'ps', 'worker'
nserver, nworker = args.nserver, args.nworker

# Optional keyword arguments for the cluster: an option is forwarded only
# when it was actually set (truthy) on the command line.
extra_kw = {
    key: value
    for key, value in (
        ('containerizer_type', args.containerizer_type),
        ('force_pull_image', args.force_pull_image),
    )
    if value
}

# One entry per job, server job first. Both jobs run the same command and
# differ only in name, task count and per-task resources.
jobs_def = [
    {
        'name': server_name,
        'num': nserver,
        'cpus': args.server_cpus,
        'gpus': args.server_gpus,
        'mem': args.server_mem,
        'cmd': cmd,
    },
    {
        'name': worker_name,
        'num': nworker,
        'cpus': args.worker_cpus,
        'gpus': args.worker_gpus,
        'mem': args.worker_mem,
        'cmd': cmd,
    },
]


# Root logger setup: DEBUG output with -v/--verbose, INFO otherwise.
logging.basicConfig(
    level=logging.DEBUG if args.verbose else logging.INFO,
    format='[tfrun] %(threadName)s %(asctime)-15s %(message)s',
)

# Listening socket on an OS-assigned port (port 0) on all interfaces.
lfd = socket.socket()
lfd.bind(('', 0))
lfd.listen(10)

# Advertise this host's name together with the port the OS picked.
port = lfd.getsockname()[1]
addr = (socket.gethostname(), port)

# Sockets watched by the select() loop; accepted connections are appended.
fds = [lfd]

# Maps the first worker task to the listening address.
# NOTE(review): presumably tfmesos uses this to send that task's output back
# to this process -- confirm against the tfmesos `cluster` API.
forward_addresses = {'/job:%s/task:%s' % (worker_name, 0): addr}


# recv() returns bytes. On Python 3 the text-mode sys.stdout rejects bytes,
# so write through its underlying binary buffer; on Python 2 sys.stdout has
# no `buffer` attribute and accepts byte strings directly.
stdout_bytes = getattr(sys.stdout, 'buffer', sys.stdout)

# Start the cluster and, until it reports that it has finished, copy
# everything received on the forwarded connections to our stdout.
# The context variable gets its own name so that it does not shadow the
# imported `cluster` factory.
with cluster(jobs_def, role=args.role, master=args.master,
             quiet=not args.verbose, volumes=volumes, name=args.name,
             forward_addresses=forward_addresses,
             **extra_kw) as running_cluster:
    try:
        while not running_cluster.finished():
            # Short timeout so finished() is polled regularly even when no
            # socket has data.
            readables = select.select(fds, [], [], 0.1)[0]
            for fd in readables:
                if fd is lfd:
                    # New incoming connection: start watching it as well.
                    fds.append(lfd.accept()[0])
                    continue

                try:
                    data = fd.recv(4096)
                except socket.error as exc:
                    # A broken peer connection should not abort the whole
                    # run; treat it like a closed connection.
                    logging.warning(
                        'dropping forwarded connection: %s', exc)
                    data = b''

                if data:
                    stdout_bytes.write(data)
                    # Flush so forwarded output shows up promptly even when
                    # stdout is a pipe or file.
                    stdout_bytes.flush()
                else:
                    # Empty read: the peer closed the connection.
                    fds.remove(fd)
                    fd.close()
    finally:
        # Close the listening socket and any remaining connections, also
        # when the loop is left through an exception.
        for fd in fds:
            fd.close()
