#!/usr/bin/env python
"""Data generation for Galaxy2Galaxy.

This script is used to generate data to train your models
for a number of problems for which open-source data is available.

For example, to generate data for MNIST run this:

t2t-datagen \
  --problem=image_mnist \
  --data_dir=~/t2t_data \
  --tmp_dir=~/t2t_data/tmp
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensor2tensor.bin.t2t_datagen import *
from galaxy2galaxy import problems
from galaxy2galaxy.utils import registry

import tensorflow as tf

def main(argv):
  """Generates data for every registered problem selected by the flags.

  The candidate list starts from all registered problems, drops any whose
  name contains one of the comma-separated `--exclude_problems` substrings,
  and is then narrowed by `--problem`, which may be a single name, a
  comma-separated list of names, or a prefix ending in "*". Problems whose
  name contains "timit" or "parsing_english_ptb" are dropped unless the
  matching `--timit_paths` / `--parsing_path` flag is set.

  Args:
    argv: Positional command-line arguments passed by `tf.app.run`. Unused.

  Raises:
    ValueError: If no problem is left to generate after filtering.
  """
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)

  # Keep the complete listing around: it is reused for the error message.
  all_problems = sorted(registry.list_problems())

  # Calculate the list of problems to generate. The local is deliberately not
  # called `problems` so it does not shadow the imported `problems` module.
  selected = all_problems
  for exclude in FLAGS.exclude_problems.split(","):
    if exclude:
      selected = [p for p in selected if exclude not in p]
  if FLAGS.problem and FLAGS.problem[-1] == "*":
    prefix = FLAGS.problem[:-1]
    selected = [p for p in selected if p.startswith(prefix)]
  elif FLAGS.problem and "," in FLAGS.problem:
    # Split once instead of once per candidate.
    requested = set(FLAGS.problem.split(","))
    selected = [p for p in selected if p in requested]
  elif FLAGS.problem:
    selected = [p for p in selected if p == FLAGS.problem]
  else:
    selected = []

  # Remove TIMIT if paths are not given. `getattr` with a default is used
  # because the flag may not be defined at all.
  if not getattr(FLAGS, "timit_paths", None):
    selected = [p for p in selected if "timit" not in p]
  # Remove parsing if paths are not given.
  if not getattr(FLAGS, "parsing_path", None):
    selected = [p for p in selected if "parsing_english_ptb" not in p]

  if not selected:
    problems_str = "\n  * ".join(all_problems)
    error_msg = ("You must specify one of the supported problems to "
                 "generate data for:\n  * " + problems_str + "\n")
    error_msg += ("TIMIT and parsing need data_sets specified with "
                  "--timit_paths and --parsing_path.")
    raise ValueError(error_msg)

  if not FLAGS.data_dir:
    FLAGS.data_dir = tempfile.gettempdir()
    tf.logging.warning(
        "It is strongly recommended to specify --data_dir. "
        "Data will be written to default data_dir=%s.", FLAGS.data_dir)
  FLAGS.data_dir = os.path.expanduser(FLAGS.data_dir)
  tf.gfile.MakeDirs(FLAGS.data_dir)

  tf.logging.info("Generating problems:\n%s",
                  registry.display_list_by_prefix(selected, starting_spaces=4))
  if FLAGS.only_list:
    return
  for problem_name in selected:
    # Re-seed before each problem, as the original loop did.
    set_random_seed()

    if problem_name in registry.list_base_problems():
      generate_data_for_registered_problem(problem_name)
    elif problem_name in registry.list_env_problems():
      generate_data_for_env_problem(problem_name)
    else:
      tf.logging.error("Problem %s is not a supported problem for datagen.",
                       problem_name)

if __name__ == "__main__":
  # Script entry point: log at INFO level, then let TensorFlow's app runner
  # parse the flags and invoke this module's `main`.
  tf.logging.set_verbosity(tf.logging.INFO)
  tf.app.run(main=main)
