#!/usr/bin/env python
"""Export a trained model as tf hub modules."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from tensor2tensor.bin import t2t_trainer
from tensor2tensor.utils import decoding
from tensor2tensor.utils import t2t_model
from tensor2tensor.utils import trainer_lib
from tensor2tensor.utils import usr_dir

from galaxy2galaxy import models
from galaxy2galaxy import problems

import tensorflow as tf
import tensorflow_hub as hub

# Shared flag registry. Besides the two flags defined below, this script reads
# flags such as --output_dir, --model and --hparams_set that are not defined in
# this file (presumably registered by the imported t2t_trainer module).
FLAGS = tf.flags.FLAGS

# NOTE: adjacent string literals are concatenated verbatim, so each help
# sentence needs its own trailing space before the next literal.
tf.flags.DEFINE_string(
    "export_dir", None, "Directory, where export model should be stored. "
    "If None, the model will be stored in subdirectory "
    "where checkpoints are: --output_dir")

tf.flags.DEFINE_string(
    "checkpoint_path", None, "Which checkpoint to export. "
    "If None, we will use the latest checkpoint stored in the directory "
    "specified by --output_dir")

def _get_hparams_path():
  """Locate the hparams.json file stored next to the checkpoints.

  The file is searched for in --output_dir when that flag is set; otherwise in
  the directory containing --checkpoint_path. A leading "~" in --output_dir is
  expanded so that the lookup agrees with how main() resolves the checkpoint
  directory from the same flag.

  Returns:
    The path to an existing hparams.json, or None when neither flag is set or
    the file does not exist (a warning is logged in that case).
  """
  hparams_path = None
  if FLAGS.output_dir:
    # Expand "~" here: main() does the same before looking for checkpoints, and
    # without it tf.gfile.Exists would miss a file under the home directory.
    hparams_path = os.path.join(
        os.path.expanduser(FLAGS.output_dir), "hparams.json")
  elif FLAGS.checkpoint_path:  # Infer hparams.json from checkpoint path
    hparams_path = os.path.join(
        os.path.dirname(FLAGS.checkpoint_path), "hparams.json")

  # Check if hparams_path really exists
  if hparams_path:
    if tf.gfile.Exists(hparams_path):
      tf.logging.info("hparams file %s exists", hparams_path)
    else:
      tf.logging.info("hparams file %s does not exist", hparams_path)
      hparams_path = None

  # Can't find hparams_path
  if not hparams_path:
    tf.logging.warning(
        "--output_dir not specified or file hparams.json does not exists. "
        "Hyper-parameters will be infered from --hparams_set and "
        "--hparams only. These may not match training time hyper-parameters.")

  return hparams_path


def create_estimator(run_config, hparams):
  """Build the estimator for the model named by --model.

  Args:
    run_config: run configuration forwarded to trainer_lib.create_estimator.
    hparams: hyper-parameters object forwarded to trainer_lib.create_estimator.

  Returns:
    Whatever trainer_lib.create_estimator returns for these settings.
  """
  model_name = FLAGS.model
  # Decoding hyper-parameters are parsed from the --decode_hparams flag string.
  decode_hp = decoding.decode_hparams(FLAGS.decode_hparams)
  on_tpu = FLAGS.use_tpu
  return trainer_lib.create_estimator(
      model_name, hparams, run_config,
      decode_hparams=decode_hp, use_tpu=on_tpu)


def create_hparams():
  """Create hyper-parameters object.

  Hyper-parameters come from --hparams_set and --hparams, combined with the
  hparams.json file found by _get_hparams_path() when one exists.

  Returns:
    The object returned by trainer_lib.create_hparams.
  """
  data_dir = FLAGS.data_dir
  if data_dir:
    # Only expand a real path: os.path.expanduser(None) raises TypeError, which
    # would abort the export whenever --data_dir is left unset.
    data_dir = os.path.expanduser(data_dir)
  # NOTE(review): an unset --data_dir is forwarded unchanged; this assumes
  # trainer_lib.create_hparams treats a missing data_dir as optional - confirm.
  return trainer_lib.create_hparams(
      FLAGS.hparams_set,
      FLAGS.hparams,
      data_dir=data_dir,
      problem_name=FLAGS.problem,
      hparams_path=_get_hparams_path())

def main(_):
  """Export the registered tf hub modules of a trained model.

  Uses --checkpoint_path when given, otherwise the latest checkpoint found in
  --output_dir. The export goes to --export_dir, or to an "export"
  subdirectory of the checkpoint directory when that flag is not set.

  Args:
    _: unused positional argument supplied by tf.app.run.
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)

  # Resolve the checkpoint to export and the directory that contains it.
  checkpoint_path = FLAGS.checkpoint_path
  if checkpoint_path:
    ckpt_dir = os.path.dirname(checkpoint_path)
  else:
    ckpt_dir = os.path.expanduser(FLAGS.output_dir)
    checkpoint_path = tf.train.latest_checkpoint(ckpt_dir)

  hparams = create_hparams()
  hparams.no_data_parallelism = True  # To clear the devices
  problem = hparams.problem
  decode_hp = decoding.decode_hparams(FLAGS.decode_hparams)

  if FLAGS.export_dir:
    export_dir = FLAGS.export_dir
  else:
    export_dir = os.path.join(ckpt_dir, "export")

  estimator = create_estimator(t2t_trainer.create_run_config(hparams), hparams)

  def _serving_input_fn():
    # Evaluated lazily by the exporter; --use_tpu is read at call time.
    return problem.serving_input_fn(hparams, decode_hp, FLAGS.use_tpu)

  # Use tf hub to export any module that has been registered
  hub_exporter = hub.LatestModuleExporter("tf_hub", _serving_input_fn)
  hub_exporter.export(estimator, export_dir, checkpoint_path=checkpoint_path)


if __name__ == "__main__":
  # Script entry point: turn on INFO logging, then hand control to tf.app.run,
  # which calls main().
  tf.logging.set_verbosity(tf.logging.INFO)
  tf.app.run(main=main)
