#!/usr/bin/env python
import argparse
import ast
import importlib
import json
import math
import os
import sys

from pyfiglet import Figlet

class CLI:
    """Command-line front end that loads a model file and starts training.

    The entry-point file must define ``get_estimator(**hyperparameters)``; the
    object it returns gets ``num_process`` assigned and ``fit(inputs=...)`` called.
    Hyperparameters come either from a JSON file (``--hyperparameters``) or from
    extra ``--name value`` command-line arguments, but not both.
    """

    # Characters that delimit elements inside a list/tuple hyperparameter value.
    _DS_DELIMITERS = ",[]()"

    def __init__(self):
        self.input_path = None
        self.entry_point = None
        self.num_process = None

    def draw(self):
        """Print the FastEstimator ASCII-art banner."""
        print(Figlet(font="slant").renderText("FastEstimator"))

    def run(self, argv=None):
        """Parse command-line arguments and launch training.

        Args:
            argv: Argument list to parse. ``None`` (the default) makes argparse
                read ``sys.argv[1:]``, matching the previous behavior.

        Raises:
            ValueError: If ``--entry_point`` is missing or the hyperparameters are
                invalid (see ``_build_hyperparameters``).
        """
        parser = argparse.ArgumentParser()
        parser.add_argument('--entry_point', type=str, help='Enter path to the model python file')
        parser.add_argument('--hyperparameters', type=str, help='Enter path to the hyperparameters JSON file')
        parser.add_argument('--input', type=str, help='Enter the path where tfrecord is saved or will be saved')
        parser.add_argument('--num_process', type=int, help='Number of gpu to use', default=1)

        # Anything argparse does not recognize is treated as a model hyperparameter.
        args, unknown = parser.parse_known_args(argv)
        self.input_path = args.input
        self.entry_point = args.entry_point
        self.num_process = args.num_process

        if self.entry_point is None:
            raise ValueError("the entry point file must be specified using --entry_point")

        hyperparameters = self._build_hyperparameters(args.hyperparameters, unknown)
        self.train(hyperparameters)

    def _build_hyperparameters(self, json_path, extra_args):
        """Assemble the keyword arguments passed to ``get_estimator``.

        Args:
            json_path: Optional path to a JSON file holding an object of hyperparameters.
            extra_args: Unrecognized command-line tokens (``--name value`` groups).

        Returns:
            dict mapping hyperparameter names to Python values.

        Raises:
            ValueError: If the JSON file does not contain an object, if a non-empty
                JSON file and command-line hyperparameters are both given, or if a
                command-line value cannot be parsed.
        """
        hyperparameters = {}
        if json_path:
            # Context manager so the file handle is closed promptly.
            with open(os.path.abspath(json_path), 'r') as json_file:
                hyperparameters = json.load(json_file)
            if not isinstance(hyperparameters, dict):
                raise ValueError("hyperparameters JSON file must contain an object")

        if extra_args:
            # Explicit raise instead of assert: asserts vanish under `python -O`.
            if hyperparameters:
                raise ValueError("json file and arguments shouldn't both be used")
            for key, literal in self._cli_parser(extra_args).items():
                try:
                    # literal_eval accepts only literals, so command-line text can
                    # never execute arbitrary code.
                    hyperparameters[key] = ast.literal_eval(literal)
                except (ValueError, SyntaxError) as err:
                    raise ValueError(
                        "could not parse value {} for hyperparameter --{}".format(literal, key)) from err
        return hyperparameters

    def train(self, hyperparameters):
        """Import the entry-point module, build its estimator and fit it.

        The entry point's directory is prepended to ``sys.path`` so the module
        (and any sibling modules it imports) can be imported by name.
        """
        module_name = os.path.splitext(os.path.basename(self.entry_point))[0]
        dir_name = os.path.abspath(os.path.dirname(self.entry_point))
        sys.path.insert(0, dir_name)
        spec_module = importlib.import_module(module_name)
        estimator = spec_module.get_estimator(**hyperparameters)
        estimator.num_process = self.num_process
        estimator.fit(inputs=self.input_path)

    def _is_float(self, value):
        """Return True if ``float(value)`` succeeds."""
        try:
            float(value)
        except (TypeError, ValueError):
            return False
        return True

    def _is_int(self, value):
        """Return True if ``int(value)`` succeeds."""
        try:
            int(value)
        except (TypeError, ValueError):
            return False
        return True

    def _string_builder(self, value):
        """Return Python literal source for a bare-word token.

        ``None``, ``True`` and ``False`` pass through unquoted so they evaluate
        to the constants; anything else becomes a correctly escaped string literal.
        """
        if value in ('None', 'True', 'False'):
            return value
        # repr() escapes quotes and backslashes, so e.g. C:\new round-trips
        # instead of "\n" turning into a newline.
        return repr(value)

    def _is_data_structure(self, value):
        """Return True if ``value`` starts like a list or tuple (empty-safe)."""
        return value[:1] in ('[', '(')

    def _find_delimiter(self, value):
        """Return the index of the first list/tuple delimiter in ``value``.

        Returns ``len(value)`` when there is none, so the caller treats the rest
        of the string as one element.
        """
        for index, char in enumerate(value):
            if char in self._DS_DELIMITERS:
                return index
        return len(value)

    def _literal_from_token(self, token):
        """Return Python literal source for one scalar token.

        Integers and finite floats are re-rendered with ``repr`` so inputs such as
        ``007`` or ``1_000`` become valid literals. Everything else, including
        ``inf``/``nan``, which have no literal form, is treated as a string.
        """
        if self._is_int(token):
            return repr(int(token))
        if self._is_float(token) and math.isfinite(float(token)):
            return repr(float(token))
        return self._string_builder(token)

    def _ds_builder(self, value):
        """Convert a bracketed list/tuple expression into Python literal source.

        Structural characters (``, [ ] ( )``) are copied through. Each element
        between them is stripped of surrounding whitespace and rendered with
        ``_literal_from_token``, so ``[a, -1, 2e-3]`` becomes ``['a',-1,0.002]``.
        """
        pieces = []
        i = 0
        while i < len(value):
            if value[i] in self._DS_DELIMITERS:
                pieces.append(value[i])
                i += 1
                continue
            # value[i] is not a delimiter, so end > i and the loop always advances.
            end = i + self._find_delimiter(value[i:])
            token = value[i:end].strip()
            if token:  # whitespace between delimiters is only formatting
                pieces.append(self._literal_from_token(token))
            i = end
        return "".join(pieces)

    def _parse(self, value):
        """Return Python literal source for one hyperparameter's raw value."""
        value = value.strip()
        if self._is_data_structure(value):
            return self._ds_builder(value)
        return self._literal_from_token(value)

    def _cli_parser(self, args):
        """Group ``--name value...`` tokens into ``{name: python_literal_source}``.

        Tokens following a ``--name`` flag, up to the next ``--`` flag, are joined
        with spaces, so ``--x [1, 2]`` works even though the shell splits it.
        ``--name=value`` is also accepted.

        Raises:
            ValueError: On a token before the first flag, an empty flag name,
                or a flag with no value.
        """
        groups = []  # (name, [value tokens]) in command-line order
        for token in args:
            if token.startswith('--'):
                name, has_inline, inline = token.lstrip('-').partition('=')
                if not name:
                    raise ValueError("invalid hyperparameter flag {!r}".format(token))
                groups.append((name, [inline] if has_inline else []))
            elif groups:
                groups[-1][1].append(token)
            else:
                raise ValueError(
                    "unexpected argument {!r}; hyperparameters must be given as --name value".format(token))

        hyperparameters = {}
        for name, tokens in groups:
            raw = " ".join(tokens).strip()
            if not raw:
                raise ValueError("missing value for hyperparameter --{}".format(name))
            hyperparameters[name] = self._parse(raw)
        return hyperparameters

def run():
    """Script entry point: show the banner, then parse arguments and train."""
    app = CLI()
    app.draw()
    app.run()

if __name__ == "__main__":
    # Only launch when executed as a script, not when imported as a module.
    run()