#!/usr/bin/env python


from __future__ import print_function
import argparse
import codecs
import logging
import pickle

from cort.core import corpora
from cort.core import mention_extractor
from cort.coreference import cost_functions
from cort.coreference import experiments
from cort.coreference import features
from cort.coreference import instance_extractors
from cort.util import import_helper


__author__ = 'smartschat'


def parse_args(argv=None):
    """Parse the command-line options of the prediction script.

    Args:
        argv (list of str, optional): Argument strings to parse. When
            ``None`` (the default), argparse reads ``sys.argv[1:]``.

    Returns:
        argparse.Namespace: The parsed options, with the attributes
        ``input_filename``, ``model``, ``output_filename``, ``ante``,
        ``extractor``, ``perceptron``, ``clusterer`` and ``features``.
        ``ante`` and ``features`` are ``None`` when not supplied.

    Raises:
        SystemExit: Raised by argparse when a required option is missing or
            when ``-h`` is given.
    """
    parser = argparse.ArgumentParser(description='Predict coreference '
                                                 'relations.')
    parser.add_argument('-in',
                        required=True,
                        dest='input_filename',
                        help='The input file. Must follow the format of the '
                             'CoNLL shared tasks on coreference resolution '
                             '(see http://conll.cemantix.org/2012/data.html).')
    parser.add_argument('-model',
                        required=True,
                        dest='model',
                        help='The model learned via cort-train.')
    parser.add_argument('-out',
                        dest='output_filename',
                        required=True,
                        help='The output file the predictions will be stored '
                             'in (in the CoNLL format).')
    # Optional: the script only writes antecedent decisions when this option
    # is set (it checks ``if args.ante:`` before writing).
    parser.add_argument('-ante',
                        dest='ante',
                        help='The file where antecedent predictions will be '
                             'stored to. If not provided, antecedent '
                             'decisions are not written.')
    parser.add_argument('-extractor',
                        dest='extractor',
                        required=True,
                        help='The function to extract instances.')
    parser.add_argument('-perceptron',
                        dest='perceptron',
                        required=True,
                        help='The perceptron to use.')
    parser.add_argument('-clusterer',
                        dest='clusterer',
                        required=True,
                        help='The clusterer to use.')
    parser.add_argument('-features',
                        dest='features',
                        help='The file containing the list of features. If '
                             'not provided, defaults to a standard set of '
                             'features.')

    return parser.parse_args(argv)

# Emit timestamped log records at INFO level and above.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s %(levelname)s %(message)s',
)

args = parse_args()

if not args.features:
    # No feature file was given: use the built-in standard feature set.
    mention_features = [
        features.fine_type,
        features.gender,
        features.number,
        features.sem_class,
        features.gr_func,
        features.head_ner,
        features.length,
        features.head,
        features.first,
        features.last,
        features.preceding_token,
        features.next_token,
    ]
    pairwise_features = [
        features.exact_match,
        features.head_match,
        features.same_speaker,
        features.alias,
        features.sentence_distance,
        features.embedding,
        features.modifier,
    ]
else:
    # The feature lists are read from the file named on the command line.
    mention_features, pairwise_features = import_helper.get_features(
        args.features)

# The instance extraction function is named on the command line and resolved
# at runtime; it is combined with the chosen features and the null cost
# function.
extractor = instance_extractors.InstanceExtractor(
    import_helper.import_from_path(args.extractor),
    mention_features,
    pairwise_features,
    cost_functions.null_cost,
)

logging.info("Loading model.")
# NOTE: unpickling can execute arbitrary code, so only load model files that
# come from a trusted source.
# The context manager guarantees the model file is closed once it is read.
with open(args.model, "rb") as model_file:
    priors, weights = pickle.load(model_file)

# The perceptron class is named on the command line and resolved at runtime.
# cost_scaling=0 goes together with the null cost function used by the
# extractor — presumably costs only matter during training; confirm.
perceptron = import_helper.import_from_path(args.perceptron)(
    priors=priors,
    weights=weights,
    cost_scaling=0
)

logging.info("Reading in data.")
# The input is decoded as UTF-8. The context manager closes the file as soon
# as the corpus has been built.
# NOTE(review): this assumes Corpus.from_file consumes the stream before
# returning rather than reading it lazily — confirm.
with codecs.open(args.input_filename, "r", "utf-8") as input_file:
    testing_corpus = corpora.Corpus.from_file("testing", input_file)

logging.info("Extracting system mentions.")
for doc in testing_corpus:
    doc.system_mentions = mention_extractor.extract_system_mentions(doc)

# The clustering function is named on the command line and resolved at
# runtime.
clusterer = import_helper.import_from_path(args.clusterer)

mention_entity_mapping, antecedent_mapping = experiments.predict(
    testing_corpus, extractor, perceptron, clusterer)

# Hand both predicted mappings back to the corpus.
testing_corpus.read_coref_decisions(mention_entity_mapping,
                                    antecedent_mapping)

logging.info("Write corpus to file.")
# Context managers make sure the output is flushed and closed even if writing
# fails part-way.
# NOTE(review): the input is decoded as UTF-8 but the output files use the
# platform default encoding — confirm this is intended.
with open(args.output_filename, "w") as output_file:
    testing_corpus.write_to_file(output_file)

# Antecedent decisions are only written when a target file was given.
if args.ante:
    logging.info("Write antecedent decisions to file")
    with open(args.ante, "w") as ante_file:
        testing_corpus.write_antecedent_decisions_to_file(ante_file)

logging.info("Done.")
