#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import sys
import argparse
from theanolm.commands import *
from theanolm.exceptions import *

def main():
    """Parse the command line and run the selected theanolm sub-command.

    Registers the ``train``, ``score``, ``decode``, ``sample`` and
    ``version`` sub-commands. Each sub-parser stores its handler function in
    the ``command_function`` default, and that handler is called with the
    parsed arguments. When no sub-command is given, the top-level help text
    is printed instead.

    If the handler raises ``NumberError``, an explanation is written to
    standard error and the process exits with status 2.
    """
    parser = argparse.ArgumentParser(prog='theanolm')
    subparsers = parser.add_subparsers(
        title='commands',
        help='selects the command to perform ("theanolm command --help" '
             'displays help for the specific command)')

    train_parser = subparsers.add_parser(
        'train', help='train a model')
    train.add_arguments(train_parser)
    train_parser.set_defaults(command_function=train.train)

    score_parser = subparsers.add_parser(
        'score', help='score text or n-best lists using a model')
    score.add_arguments(score_parser)
    score_parser.set_defaults(command_function=score.score)

    decode_parser = subparsers.add_parser(
        'decode', help='decode a word lattice using a model')
    decode.add_arguments(decode_parser)
    decode_parser.set_defaults(command_function=decode.decode)

    sample_parser = subparsers.add_parser(
        'sample', help='generate text using a model')
    sample.add_arguments(sample_parser)
    sample_parser.set_defaults(command_function=sample.sample)

    # The version command registers no arguments of its own.
    version_parser = subparsers.add_parser(
        'version', help='display the version number')
    version_parser.set_defaults(command_function=version.version)

    args = parser.parse_args()

    # The attribute is only set by a sub-parser, so its absence means that no
    # sub-command was given on the command line.
    if not hasattr(args, 'command_function'):
        parser.print_help()
        return

    try:
        args.command_function(args)
    except NumberError as e:
        # Write the diagnostic to standard error so that it does not get
        # mixed into whatever the command itself writes to standard output.
        print('A numerical error has occurred. This may happen e.g. when '
              'network parameters go to infinity. If this happens during '
              'training, using a smaller learning rate usually helps. '
              'Another possibility is to use gradient normalization. If '
              'this happens during inference, there is probably something '
              'wrong in the model. The error message was: "' + str(e) + '"',
              file=sys.stderr)
        sys.exit(2)
# Run the command line interface only when this file is executed as a script.
if __name__ == '__main__':
    main()
