#!python

import argparse
import logging
import os
import codecs
from collections import defaultdict

from tilse.data import timelines
from tilse.evaluation import metrictests, rouge


def get_scores(metric_desc, pred_tl, groundtruth, evaluator):
    """Score a predicted timeline against the ground truth with one metric.

    The metric name selects the evaluator method ``evaluate_<metric_desc>``,
    which is called with ``(pred_tl, groundtruth)``.

    Args:
        metric_desc: Name of the metric. One of ``"concat"``,
            ``"agreement"``, ``"align_date_costs"``,
            ``"align_date_content_costs"`` or
            ``"align_date_content_costs_many_to_one"``.
        pred_tl: Predicted timeline, passed through to the evaluator.
        groundtruth: Reference timelines, passed through to the evaluator.
        evaluator: Object providing an ``evaluate_<metric>`` method for each
            supported metric.

    Returns:
        Whatever the selected evaluator method returns.

    Raises:
        ValueError: If ``metric_desc`` is not one of the supported metrics.
    """
    supported_metrics = (
        "concat",
        "agreement",
        "align_date_costs",
        "align_date_content_costs",
        "align_date_content_costs_many_to_one",
    )

    # Validate against an explicit whitelist so that an arbitrary string can
    # never be used to reach some unrelated ``evaluate_*`` attribute, and so
    # that a typo fails loudly instead of silently yielding ``None``.
    if metric_desc not in supported_metrics:
        raise ValueError(
            "Unknown metric {!r}; expected one of: {}".format(
                metric_desc, ", ".join(supported_metrics)
            )
        )

    # Dispatch on the argument itself, not on a module-level global.
    evaluate = getattr(evaluator, "evaluate_" + metric_desc)
    return evaluate(pred_tl, groundtruth)


# Emit timestamped log records at INFO level and above for this run.
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s %(levelname)s %(message)s')


def _load_timeline(path):
    """Parse the timeline stored at *path*, closing the file afterwards.

    The file is decoded as UTF-8; undecodable bytes are replaced instead of
    raising a decoding error.
    """
    # NOTE(review): assumes Timeline.from_file consumes the whole stream
    # before returning, so closing the handle here is safe -- confirm.
    with codecs.open(path, "r", "utf-8", "replace") as handle:
        return timelines.Timeline.from_file(handle)


parser = argparse.ArgumentParser(description="Evaluate timelines")

parser.add_argument('-p', dest="predicted", type=str, help='Predicted timelines',
                    required=True)

# One or more reference timeline files for the same topic.
parser.add_argument('-r', dest="reference", type=str, help='Reference timelines',
                    required=True, nargs="*")

parser.add_argument('-m', dest="metric", type=str, help='Metric to use',
                    required=True)

args = parser.parse_args()

# These names are deliberately kept at module level (no main() wrapper) so
# that any code in this file that reads them as globals keeps working.
predicted = _load_timeline(args.predicted)
reference = args.reference
metric = args.metric

# Load every reference file and bundle them into a single ground truth.
temp_ref_tls = [_load_timeline(filename) for filename in reference]
reference_timelines = timelines.GroundTruth(temp_ref_tls)

# Scores are computed with ROUGE-1 and ROUGE-2 as the underlying measures.
evaluator = rouge.TimelineRougeEvaluator(measures=["rouge_1", "rouge_2"])

scores = get_scores(metric, predicted, reference_timelines, evaluator)

print(scores)
