#!/usr/bin/env python3

import contextlib
import logging

from eflomal import Aligner, sentences_from_joint_file

import sys, argparse, os


# Module-level logger for this script; logging itself is configured in main()
# via logging.basicConfig(), depending on the --verbose flag.
logger = logging.getLogger(__name__)


def _build_parser():
    """Build the argument parser for the eflomal command-line interface.

    Returns:
        An ``argparse.ArgumentParser`` with the model, stemming, iteration,
        input-file and output-file options registered.
    """
    parser = argparse.ArgumentParser(
        description='eflomal: efficient low-memory aligner')

    # General behaviour flags.
    parser.add_argument(
        '-v', '--verbose', dest='verbose',
        action='store_true', help='Enable verbose output')
    parser.add_argument(
        '--debug', dest='debug',
        action='store_true', help='Enable gdb debugging of eflomal binary')
    parser.add_argument(
        '--overwrite', dest='overwrite',
        action='store_true', help='Overwrite existing output files')

    # Model selection and hyperparameters.
    parser.add_argument(
        '--null-prior', dest='null_prior', default=0.2, metavar='X',
        type=float, help='Prior probability of NULL alignment')
    parser.add_argument(
        '-m', '--model', dest='model', default=3, metavar='N',
        type=int, help='Model (1 = IBM1, 2 = IBM1+HMM, 3 = IBM1+HMM+fertility)')
    parser.add_argument(
        '-M', '--score-model', dest='score_model', default=0, metavar='N',
        type=int, help='Model used for sentence scoring '
                       '(1 = IBM1, 2 = IBM1+HMM, 3 = IBM1+HMM+fertility)')

    # Stemming (prefix/suffix truncation) options.
    parser.add_argument(
        '--source-prefix', dest='source_prefix_len', default=0, metavar='N',
        type=int, help='Length of prefix for stemming (source)')
    parser.add_argument(
        '--source-suffix', dest='source_suffix_len', default=0, metavar='N',
        type=int, help='Length of suffix for stemming (source)')
    parser.add_argument(
        '--target-prefix', dest='target_prefix_len', default=0, metavar='N',
        type=int, help='Length of prefix for stemming (target)')
    parser.add_argument(
        '--target-suffix', dest='target_suffix_len', default=0, metavar='N',
        type=int, help='Length of suffix for stemming (target)')

    # Sampling length / iteration counts.
    parser.add_argument(
        '-l', '--length', dest='length', default=1.0, metavar='X',
        type=float, help='Relative number of sampling iterations')
    parser.add_argument(
        '-1', '--ibm1-iters', dest='iters1', default=None, metavar='X',
        type=int, help='Number of IBM1 iterations (overrides --length)')
    parser.add_argument(
        '-2', '--hmm-iters', dest='iters2', default=None, metavar='X',
        type=int, help='Number of HMM iterations (overrides --length)')
    parser.add_argument(
        '-3', '--fert-iters', dest='iters3', default=None, metavar='X',
        type=int,
        help='Number of HMM+fertility iterations (overrides --length)')
    parser.add_argument(
        '--n-samplers', dest='n_samplers', default=3, metavar='X',
        type=int, help='Number of independent samplers to run')

    # Input files.
    parser.add_argument(
        '-s', '--source', dest='source_filename', type=str, metavar='filename',
        help='Source text filename')
    parser.add_argument(
        '-t', '--target', dest='target_filename', type=str, metavar='filename',
        help='Target text filename')
    parser.add_argument(
        '-i', '--input', dest='joint_filename', type=str, metavar='filename',
        help='fast_align style ||| separated file')

    # Output files.
    parser.add_argument(
        '-f', '--forward-links', dest='links_filename_fwd', type=str,
        metavar='filename',
        help='Filename to write forward direction alignments to')
    parser.add_argument(
        '-r', '--reverse-links', dest='links_filename_rev', type=str,
        metavar='filename',
        help='Filename to write reverse direction alignments to')
    parser.add_argument(
        '-F', '--forward-scores', dest='scores_filename_fwd', type=str,
        metavar='filename',
        help='Filename to write alignment scores to (generation '
            'probability of target sentences)')
    parser.add_argument(
        '-R', '--reverse-scores', dest='scores_filename_rev', type=str,
        metavar='filename',
        help='Filename to write alignment scores to (generation '
            'probability of source sentences)')

    # Optional priors.
    parser.add_argument(
        '-p', '--priors', dest='priors_filename', type=str, metavar='filename',
        help='File to read priors from')
    return parser


def main():
    """Run the eflomal aligner from the command line.

    Parses ``sys.argv``, validates the input and output paths, constructs an
    ``Aligner`` from the parsed options and calls its ``align`` method on
    either a joint ``|||``-separated file (``-i``) or a pair of source and
    target files (``-s`` and ``-t``). When ``-i`` is given it takes
    precedence over ``-s``/``-t``.

    Raises:
        SystemExit: with status 1, after logging an error, when no input is
            specified, when an input file does not exist, or when any
            requested output file (links or scores) already exists and
            ``--overwrite`` was not given.
    """
    args = _build_parser().parse_args()
    logging.basicConfig(level=logging.INFO if args.verbose else logging.ERROR)

    # Decide which input files are in play; the joint file wins if present.
    if args.joint_filename:
        input_filenames = (args.joint_filename,)
    elif args.source_filename and args.target_filename:
        input_filenames = (args.source_filename, args.target_filename)
    else:
        logger.error('need to specify either -s and -t, or -i')
        sys.exit(1)

    for filename in input_filenames:
        if not os.path.exists(filename):
            logger.error('input file %s does not exist!', filename)
            sys.exit(1)

    # Refuse to clobber ANY existing output file unless --overwrite is given.
    # The score files are outputs too, so they are guarded together with the
    # links files.
    if not args.overwrite:
        output_filenames = (
            args.links_filename_fwd, args.links_filename_rev,
            args.scores_filename_fwd, args.scores_filename_rev)
        for filename in output_filenames:
            if filename is not None and os.path.exists(filename):
                logger.error('output file %s exists, will not overwrite!',
                             filename)
                sys.exit(1)

    # Explicit iteration counts are forwarded only when one was given for
    # every stage of the selected model; otherwise None is passed together
    # with rel_iterations (--length).
    # NOTE(review): a partially specified set of -1/-2/-3 is dropped silently.
    iters = (args.iters1, args.iters2, args.iters3)
    if any(count is None for count in iters[:args.model]):
        iters = None

    aligner = Aligner(
        model=args.model, score_model=args.score_model,
        n_iterations=iters, n_samplers=args.n_samplers,
        rel_iterations=args.length, null_prior=args.null_prior,
        source_prefix_len=args.source_prefix_len,
        source_suffix_len=args.source_suffix_len,
        target_prefix_len=args.target_prefix_len,
        target_suffix_len=args.target_suffix_len)

    # The exit stack closes every file opened below once alignment finishes
    # or an exception propagates.
    with contextlib.ExitStack() as stack:
        if args.priors_filename:
            priors_input = stack.enter_context(
                open(args.priors_filename, 'r', encoding='utf-8'))
        else:
            priors_input = None

        if args.joint_filename:
            logger.info('Reading source/target sentences from %s...',
                        args.joint_filename)
            # The joint file is opened twice so that the source side and the
            # target side are each read from their own file handle.
            src_in_f = stack.enter_context(
                open(args.joint_filename, 'r', encoding='utf-8'))
            src_input = sentences_from_joint_file(src_in_f, 0)
            trg_in_f = stack.enter_context(
                open(args.joint_filename, 'r', encoding='utf-8'))
            trg_input = sentences_from_joint_file(trg_in_f, 1)
        else:
            src_input = stack.enter_context(
                open(args.source_filename, 'r', encoding='utf-8'))
            trg_input = stack.enter_context(
                open(args.target_filename, 'r', encoding='utf-8'))

        aligner.align(src_input, trg_input,
                      links_filename_fwd=args.links_filename_fwd,
                      links_filename_rev=args.links_filename_rev,
                      scores_filename_fwd=args.scores_filename_fwd,
                      scores_filename_rev=args.scores_filename_rev,
                      priors_input=priors_input,
                      quiet=not args.verbose, use_gdb=args.debug)


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
