#!/usr/bin/env python
import socket
import argparse
import subprocess

from six.moves import xrange
from tfmesos import cluster

# Command-line interface: task counts, per-task resources for the two job
# types (workers and servers), and the command line to launch.
parser = argparse.ArgumentParser()
parser.add_argument('-w', '--nworker', type=int, required=True)
parser.add_argument('-s', '--nserver', type=int, required=True)
parser.add_argument('-m', '--master', type=str)
parser.add_argument('-Cw', '--worker-cpus', type=float, default=1.0)
parser.add_argument('-Gw', '--worker-gpus', type=int, default=0)
parser.add_argument('-Mw', '--worker-mem', type=float, default=1024.0)
parser.add_argument('-Cs', '--server-cpus', type=float, default=1.0)
parser.add_argument('-Gs', '--server-gpus', type=int, default=0)
parser.add_argument('-Ms', '--server-mem', type=float, default=1024.0)
parser.add_argument('-v', '--verbose', action='store_true')
parser.add_argument('-V', '--volume', type=str, action='append')
parser.add_argument('cmd', type=str)
parser.add_argument('args', nargs='*')

args = parser.parse_args()

# The command and its arguments are joined with single spaces and no shell
# quoting; the resulting string is later used as a str.format template and
# executed through the shell.
cmd = ' '.join([args.cmd] + args.args)

# Every -V value is split on its first ':' into a key/value pair; when the
# same key is given more than once, the last value wins.  `volumes` stays None
# when no -V option was passed.
if args.volume is None:
    volumes = None
else:
    volumes = {}
    for spec in args.volume:
        key, sep, value = spec.partition(':')
        if not sep:
            # Report a malformed value as a usage error instead of letting
            # dict() fail with an opaque ValueError traceback.
            parser.error(
                "invalid --volume value %r: expected two parts separated "
                "by ':'" % spec
            )
        volumes[key] = value

# Job names and task counts for the two job types.
server_name, worker_name = 'ps', 'worker'
nserver, nworker = args.nserver, args.nworker

# One entry per job type; both run the same command line and differ only in
# name, task count and requested resources.
# NOTE(review): 'start': 1 presumably makes the scheduled worker indices begin
# at 1 because worker task 0 is run locally by this script -- confirm against
# tfmesos.
jobs_def = [
    {
        'name': server_name,
        'num': nserver,
        'cpus': args.server_cpus,
        'gpus': args.server_gpus,
        'mem': args.server_mem,
        'cmd': cmd,
    },
    {
        'name': worker_name,
        'num': nworker,
        'cpus': args.worker_cpus,
        'gpus': args.worker_gpus,
        'mem': args.worker_mem,
        'cmd': cmd,
        'start': 1,
    },
]


def get_addr(targets, name, i):
    """Return the address of task ``i`` of job ``name`` without its scheme.

    :param targets: mapping from ``'/job:<name>/task:<i>'`` keys to target
        URLs.
    :param name: job name used to build the lookup key.
    :param i: task index used to build the lookup key.
    :returns: the target URL with a leading ``'grpc://'`` removed; a URL
        that does not start with that prefix is returned unchanged.
    :raises KeyError: if ``targets`` has no entry for the requested task.
    """
    prefix = 'grpc://'
    url = targets['/job:%s/task:%s' % (name, i)]
    # Only strip the scheme when it is really there, so an address without it
    # is not silently truncated by a blind slice.
    if url.startswith(prefix):
        return url[len(prefix):]
    return url


def get_port():
    """Return a port number that was free at the time of the call.

    A socket is bound to port 0 so that the operating system chooses the
    port; the chosen number is read back and the socket is closed again.

    NOTE: the socket is closed before returning, so nothing keeps the port
    reserved; another process could take it before the caller binds to it.

    :returns: the port number as an int.
    """
    s = socket.socket()
    try:
        s.bind(('', 0))
        return s.getsockname()[1]
    finally:
        # Close the socket even when bind()/getsockname() raises, so the file
        # descriptor is not leaked.
        s.close()

# Address advertised for the task that this script runs locally: this
# machine's hostname plus a port obtained from get_port().
addr = '%s:%s' % (socket.gethostname(), get_port())

with cluster(jobs_def, master=args.master, quiet=not args.verbose,
             volumes=volumes, local_task=(worker_name, addr)) as targets:
    # Addresses of every server task, in task-index order.
    ps_addrs = [
        get_addr(targets, server_name, idx) for idx in xrange(nserver)
    ]

    # Worker 0 is the local task, so its address comes from `addr`; the
    # remaining workers (indices 1..nworker-1) are looked up in `targets`.
    worker_addrs = [addr]
    worker_addrs.extend(
        get_addr(targets, worker_name, idx) for idx in xrange(1, nworker)
    )

    # Fill the placeholders of the command template for the local task
    # (worker, index 0) and run it through the shell while the cluster
    # context is still open.
    local_cmd = cmd.format(
        ps_hosts=','.join(ps_addrs),
        worker_hosts=','.join(worker_addrs),
        job_name=worker_name,
        task_index=0,
    )
    subprocess.check_call(local_cmd, shell=True)
