#!/usr/bin/env python
import argparse
import os
import shlex
import sys

class CLI:
    """Dispatch a command line to the ``train`` or ``predict`` mode.

    The first element of ``arguments`` selects the mode; the remaining
    elements are the options belonging to that mode.
    """

    def __init__(self, arguments):
        """Store the argument list.

        Args:
            arguments: Command-line tokens without the program name; the
                first token is expected to be the mode.
        """
        self.arguments = arguments

    def run(self):
        """Validate the requested mode and invoke its handler.

        Returns:
            Whatever the selected handler returns.

        Raises:
            AssertionError: If no mode was supplied, or the mode is neither
                ``train`` nor ``predict``.
        """
        # An explicit table (instead of getattr on user input) guarantees that
        # only the two public modes can ever be invoked.
        handlers = {"train": self.train, "predict": self.predict}
        mode = self.arguments[0] if self.arguments else None
        if mode not in handlers:
            # Raised explicitly rather than through an `assert` statement so
            # the check still runs under `python -O`; the exception type is
            # kept for callers that already catch it.
            raise AssertionError("the mode must be specified either train or predict")
        return handlers[mode]()

    def train(self):
        """Run the training command through the system shell.

        Returns:
            The value returned by ``os.system`` for the command.
        """
        return os.system(self._build_train_command())

    def predict(self):
        """Report that prediction is not available and exit with status -1."""
        print("prediction not yet supported, but coming soon...!")
        sys.exit(-1)

    def get_train_cmd(self, arguments):
        """Build the ``fastestimator_train`` shell command line.

        Args:
            arguments: Iterable of options to forward; each is converted with
                ``str`` and shell-quoted.

        Returns:
            The command as a single string suitable for a POSIX shell.
        """
        # Quote every token: the string is handed to a shell, so an argument
        # containing spaces or metacharacters must stay a single literal word.
        quoted = [shlex.quote(str(arg)) for arg in arguments]
        # The trailing space on the program token keeps the output
        # byte-identical to earlier versions for arguments that need no quoting.
        return " ".join(["fastestimator_train "] + quoted)

    def _build_train_command(self):
        """Return the full shell command for ``train``, with mpirun if needed."""
        train_args = self.arguments[1:]
        parser = argparse.ArgumentParser()
        parser.add_argument('--num_process', type=int, help='Number of parallel training process', default=1)
        # Parse the arguments this instance was given, not the global sys.argv.
        args, _ = parser.parse_known_args(train_args)
        num_process = args.num_process
        # `--num_process` is intentionally left in the forwarded arguments,
        # matching what the training command has always received.
        cmd = self.get_train_cmd(train_args)
        if num_process > 1:
            # Launch through mpirun with that many processes on localhost.
            cmd = "mpirun -np %d -H localhost:%d --allow-run-as-root " % (num_process, num_process) + cmd
        return cmd

if __name__ == '__main__':
    # Hand every token after the script name to the CLI dispatcher.
    CLI(sys.argv[1:]).run()

