#!/usr/bin/env python


from __future__ import print_function
import argparse
import codecs
import logging
import pickle
import sys


from cort.preprocessing import pipeline
from cort.core import mention_extractor
from cort.coreference import cost_functions
from cort.coreference import experiments
from cort.coreference import features
from cort.coreference import instance_extractors
from cort.util import import_helper


__author__ = 'smartschat'

# Configure the root logger once for the whole script: INFO level and
# above, each record prefixed with a timestamp and its level name.
logging.basicConfig(
    format='%(asctime)s %(levelname)s %(message)s',
    level=logging.INFO,
)


def parse_args():
    """Parse the command-line options of this prediction script.

    The options are read from ``sys.argv``. ``-in``, ``-model``,
    ``-extractor``, ``-perceptron``, ``-clusterer`` and ``-corenlp`` are
    required; ``-suffix`` defaults to ``"out"`` and ``-features`` is
    optional (``None`` when omitted).

    Returns:
        argparse.Namespace: The parsed options, with the attributes
        ``input_filename`` (a list, since ``-in`` accepts several values),
        ``model``, ``suffix``, ``extractor``, ``perceptron``, ``clusterer``,
        ``features`` and ``corenlp``.

    Raises:
        SystemExit: Raised by ``argparse`` when a required option is missing
        or when help is requested.
    """
    parser = argparse.ArgumentParser(
        description='Predict coreference relations.')

    parser.add_argument('-in',
                        required=True,
                        dest='input_filename',
                        help='The raw text input files.',
                        nargs="*")
    parser.add_argument('-model',
                        required=True,
                        dest='model',
                        help='The model learned via cort-train.')
    parser.add_argument('-suffix',
                        dest='suffix',
                        default="out",
                        help='Suffix for output files. Defaults to "out".')
    parser.add_argument('-extractor',
                        dest='extractor',
                        required=True,
                        help='The function to extract instances.')
    parser.add_argument('-perceptron',
                        dest='perceptron',
                        required=True,
                        help='The perceptron to use.')
    parser.add_argument('-clusterer',
                        dest='clusterer',
                        required=True,
                        help='The clusterer to use.')
    # Adjacent string literals are concatenated verbatim, so every fragment
    # except the last must end with an explicit space.
    parser.add_argument('-features',
                        dest='features',
                        help='The file containing the list of features. If '
                             'not provided, defaults to a standard set of '
                             'features.')
    parser.add_argument('-corenlp',
                        dest='corenlp',
                        required=True,
                        help='Location of CoreNLP jars.')

    return parser.parse_args()


# Logging is configured once near the top of this file. A second
# logging.basicConfig call here would do nothing, because the root logger
# already has a handler by then, so it is not repeated.

if sys.version_info[0] == 2:
    logging.warning("You are running cort under Python 2. cort is much more "
                    "efficient under Python 3.3+.")

# Command-line options; parsing exits the script on missing or invalid ones.
args = parse_args()

# Choose the feature functions handed to the instance extractor: the
# built-in default set unless a feature list file was given via -features.
if not args.features:
    # Default features computed on a single mention.
    mention_features = [
        features.fine_type,
        features.gender,
        features.number,
        features.sem_class,
        features.deprel,
        features.head_ner,
        features.length,
        features.head,
        features.first,
        features.last,
        features.preceding_token,
        features.next_token,
        features.governor,
        features.ancestry,
    ]

    # Default features computed on a pair of mentions.
    pairwise_features = [
        features.exact_match,
        features.head_match,
        features.same_speaker,
        features.alias,
        features.sentence_distance,
        features.embedding,
        features.modifier,
        features.tokens_contained,
        features.head_contained,
        features.token_distance,
    ]
else:
    # User-supplied feature list file: it yields both feature lists.
    mention_features, pairwise_features = import_helper.get_features(
        args.features)


logging.info("Loading model.")
# NOTE(security): unpickling can execute arbitrary code, so only model files
# from a trusted source should be passed via -model.
# The context manager closes the model file as soon as it has been read,
# instead of leaving the handle open for the rest of the run.
with open(args.model, "rb") as model_file:
    priors, weights = pickle.load(model_file)

# Instantiate the perceptron class named on the command line with the
# learned parameters.
perceptron = import_helper.import_from_path(args.perceptron)(
    priors=priors,
    weights=weights,
    cost_scaling=0
)

# The extractor combines the user-selected extraction function with the
# chosen features and the labels known to the perceptron.
extractor = instance_extractors.InstanceExtractor(
    import_helper.import_from_path(args.extractor),
    mention_features,
    pairwise_features,
    cost_functions.null_cost,
    perceptron.get_labels()
)

logging.info("Reading in and preprocessing data.")
# Build the preprocessing pipeline from the -corenlp location and run it
# over all raw input files to obtain the corpus to predict on.
testing_corpus = pipeline.Pipeline(args.corenlp).run_on_docs(
    "corpus", args.input_filename)

logging.info("Extracting system mentions.")
# Attach the automatically extracted mentions to every document.
for document in testing_corpus:
    document.system_mentions = mention_extractor.extract_system_mentions(
        document)

# Resolve the clusterer named on the command line, then predict.
clusterer = import_helper.import_from_path(args.clusterer)
mention_entity_mapping, antecedent_mapping = experiments.predict(
    testing_corpus, extractor, perceptron, clusterer)

# Store the predicted decisions back on the corpus.
testing_corpus.read_coref_decisions(mention_entity_mapping, antecedent_mapping)

logging.info("Write output to file.")

# Write one UTF-8 file per document, named "<identifier>.<suffix>".
for doc in testing_corpus:
    # Render first, so no file is created if rendering fails.
    output = doc.to_simple_output()
    # The context manager closes the file even if the write raises.
    with codecs.open(doc.identifier + "." + args.suffix, "w",
                     "utf-8") as my_file:
        my_file.write(output)

logging.info("Done.")
