#!/usr/bin/env python3

import argparse, logging, os.path, sys

from eflomal import calculate_priors, sentences_from_joint_file, write_priors


# Module-level logger named after this module; used by main() for error output.
logger = logging.getLogger(__name__)


def main():
    """Command-line entry point: compute alignment priors and write them out.

    Sentence pairs are read either from a joint ``|||``-separated file
    (``-i``) or from separate source/target files (``-s`` and ``-t``).
    Together with the forward and reverse alignment files (``-f``/``-r``)
    they are passed to ``calculate_priors``; the result is written with
    ``write_priors`` to the file given by ``-p`` (standard output when the
    value is ``-``, which is the default).

    Exits with status 1 when the text input is not specified or when any of
    the input files does not exist. Argument errors detected by argparse
    exit with argparse's own status.
    """
    parser = argparse.ArgumentParser(
        description='Tool for creating priors from alignments')
    parser.add_argument(
        '-v', '--verbose', dest='verbose',
        action='store_true', help='Enable verbose output')
    parser.add_argument(
        '-s', '--source', dest='source_filename', type=str,
        metavar='filename',
        help='Source text file')
    parser.add_argument(
        '-t', '--target', dest='target_filename', type=str,
        metavar='filename',
        help='Target text file')
    parser.add_argument(
        '-i', '--input', dest='joint_filename', type=str,
        metavar='filename',
        help='fast_align style ||| separated file')
    parser.add_argument(
        '-f', '--forward-alignments', dest='forward_alignments_filename',
        type=str, metavar='filename', required=True,
        help='File containing forward (or symmetrized) alignments, '
             'may be same file as --reverse-alignments')
    parser.add_argument(
        '-r', '--reverse-alignments', dest='reverse_alignments_filename',
        type=str, metavar='filename', required=True,
        help='File containing reverse (or symmetrized) alignments, '
             'may be same file as --forward-alignments')
    parser.add_argument(
        '-p', '--priors', dest='priors_filename', type=str,
        metavar='filename', default='-',
        help='File to write priors to (for use with align.py --priors)')
    args = parser.parse_args()
    logging.basicConfig(level=logging.INFO if args.verbose else logging.ERROR)

    if not (args.joint_filename or
            (args.source_filename and args.target_filename)):
        logger.error('need to specify either -s and -t, or -i')
        sys.exit(1)

    # The joint file takes precedence over separate source/target files.
    if args.joint_filename:
        text_filenames = (args.joint_filename,)
    else:
        text_filenames = (args.source_filename, args.target_filename)

    # Validate the alignment files as well as the text files, so a missing
    # file always yields the same error message and exit status.
    alignment_filenames = (args.forward_alignments_filename,
                           args.reverse_alignments_filename)
    for filename in text_filenames + alignment_filenames:
        if not os.path.exists(filename):
            logger.error('input file %s does not exist!', filename)
            sys.exit(1)

    if args.joint_filename:
        with open(args.forward_alignments_filename, 'r',
                  encoding='utf-8') as fwdf, \
             open(args.reverse_alignments_filename, 'r',
                  encoding='utf-8') as revf, \
             open(args.joint_filename, 'r', encoding='utf-8') as jointf:
            # zip(*...) splits the (source, target) pairs into two parallel
            # sequences, passed as the first two positional arguments.
            priors_list, hmmf_priors, hmmr_priors, ferf_priors, ferr_priors = \
                calculate_priors(
                    *zip(*sentences_from_joint_file(jointf)), fwdf, revf)
    else:
        with open(args.forward_alignments_filename, 'r',
                  encoding='utf-8') as fwdf, \
             open(args.reverse_alignments_filename, 'r',
                  encoding='utf-8') as revf, \
             open(args.source_filename, 'r', encoding='utf-8') as srcf, \
             open(args.target_filename, 'r', encoding='utf-8') as trgf:
            priors_list, hmmf_priors, hmmr_priors, ferf_priors, ferr_priors = \
                calculate_priors(srcf, trgf, fwdf, revf)

    if args.priors_filename == '-':
        # Write to standard output but never close it: the stream is owned
        # by the interpreter and may still be needed by the caller.
        write_priors(sys.stdout, priors_list, hmmf_priors, hmmr_priors,
                     ferf_priors, ferr_priors)
        sys.stdout.flush()
    else:
        # The context manager closes the file even if write_priors raises.
        with open(args.priors_filename, 'w', encoding='utf-8') as priorsf:
            write_priors(priorsf, priors_list, hmmf_priors, hmmr_priors,
                         ferf_priors, ferr_priors)


# Run the command-line tool only when this file is executed as a script.
if __name__ == '__main__':
    main()
