#!/usr/bin/env python

import click
from transcription_compare.levenshtein_distance_calculator import UKKLevenshteinDistanceCalculator
from transcription_compare.tokenizer import CharacterTokenizer, WordTokenizer
from transcription_compare.local_optimizer.digit_util import DigitUtil
from transcription_compare.local_optimizer.local_cer_optimizer import LocalCerOptimizer
import os
from transcription_compare.utils.plot_util import plot_alignment_result_only_distance
from transcription_compare.results import MultiResult
import json
import time


def _load_reference(reference, reference_file):
    """Return the reference text, preferring the inline string over the file.

    Raises ValueError when neither --reference nor --reference_file was given.
    """
    if reference is not None:
        return reference
    if reference_file is not None:
        return reference_file.read()
    raise ValueError("One of --reference and --reference_file must be specified")


def _collect_outputs(output, output_file):
    """Build the mapping (output identifier -> output string).

    Inline strings are keyed "string_output_<index>"; files are keyed by their
    base name. NOTE: two files sharing a base name overwrite each other here.
    """
    outputs = {"string_output_{}".format(index): text for index, text in enumerate(output)}
    for handle in output_file:
        content = handle.read()
        outputs[os.path.basename(handle.name)] = content
    return outputs


def _build_calculator(error_type, alignment, is_multiple):
    """Create the distance calculator matching the requested error type."""
    if error_type == "CER":
        return UKKLevenshteinDistanceCalculator(
            tokenizer=CharacterTokenizer(),
            get_alignment_result=alignment
        )

    local_optimizers = [LocalCerOptimizer()]
    if not is_multiple:
        # DigitUtil is skipped for multi-way comparison because the results
        # would have different lengths.
        local_optimizers.append(DigitUtil())

    return UKKLevenshteinDistanceCalculator(
        tokenizer=WordTokenizer(),
        get_alignment_result=alignment,
        local_optimizers=local_optimizers,
        is_master=True
    )


def _output_path(file_path, file_name):
    """Place file_name inside file_path when one was given, else in the cwd."""
    if file_path is None:
        return file_name
    return os.path.join(file_path, file_name)


@click.command()
@click.option('--reference', '-r', type=str, help='source string')
@click.option('--output', '-o',  type=str, multiple=True, help='target string')
@click.option('--reference_file', '-R', type=click.File('r'), help='source file path')
@click.option('--output_file', '-O', multiple=True, type=click.File('r'), help='target file path')
@click.option('--alignment', '-a', default=False, is_flag=True,
              help='Do you want to see the alignment result? True/False')
@click.option('--error_type', '-e', default='CER', type=click.Choice(['CER', 'WER']))
@click.option('--output_format', '-j', default='TABLE',
              type=click.Choice(['JSON', 'TABLE', 'HTML']))
@click.option('--brackets', '-b', multiple=True, default=('()',),
              help='example input: -b () -b [] -b {} -b <>')
@click.option('--to_lower', '-l', default=False, is_flag=True, help='Do you want to lower all the words? True/False')
@click.option('--remove_punctuation', '-p', default=False, is_flag=True,
              help='Do you want to remove all the punctuation? True/False')
@click.option('--use_alternative_spelling', '-u', default=False, is_flag=True,
              help='Ignore error made by alternative spelling? True/False')
@click.option('--to_save_plot', '-P', default=False, is_flag=True, help='Do you want to see the windows? True/False')
@click.option('--to_edit_step', '-s', type=int, default=500, help='Please enter the step')
@click.option('--to_edit_width', '-w', type=int, default=500, help='Please enter the width')
@click.option('--file_path', '-f', help='Please enter the path where you would like to save the files')
def main(reference, output, reference_file, output_file, alignment, error_type, output_format, brackets, to_lower,
         remove_punctuation, to_save_plot, to_edit_step, to_edit_width, file_path, use_alternative_spelling):

    """
    Transcription compare tool provided by VoiceGain
    """
    # NOTE: the docstring above is what click prints for --help, so the
    # detailed documentation lives in these comments instead.
    #
    # Compares one reference transcript against one or more outputs and
    # reports the result as a table (stdout), a JSON file or an HTML file,
    # optionally followed by alignment plots.
    #
    #   reference / reference_file : reference text, inline or as an open file
    #                                (the inline string wins when both are given)
    #   output / output_file       : one or more texts to compare, inline and/or files
    #   alignment                  : passed to the calculator as get_alignment_result
    #   error_type                 : 'CER' (character tokens) or 'WER' (word tokens)
    #   output_format              : 'TABLE', 'JSON' or 'HTML'
    #   brackets                   : bracket pairs such as '()' forwarded to get_distance
    #   to_lower, remove_punctuation, use_alternative_spelling :
    #                                normalisation flags forwarded to get_distance
    #   to_save_plot               : call the plotting helper after reporting
    #   to_edit_step, to_edit_width: step / width forwarded to the plotting helper
    #   file_path                  : existing directory for generated files
    #
    # Raises ValueError when file_path is not a directory, when no reference
    # is supplied, or when no output is supplied.
    if file_path is not None and not os.path.isdir(file_path):
        raise ValueError("No such file or directory")

    reference = _load_reference(reference, reference_file)

    total_outputs = len(output) + len(output_file)
    if total_outputs == 0:
        raise ValueError("One of --output and --output_file must be specified")
    is_multiple = total_outputs > 1

    calculator = _build_calculator(error_type, alignment, is_multiple)
    output_all = _collect_outputs(output, output_file)
    brackets_list = list(brackets)

    output_results = dict()  # (output identifier -> distance result)
    for (key, value) in output_all.items():
        print("Start to process {}".format(key))
        output_results[key] = calculator.get_distance(
            reference, value,
            brackets_list=brackets_list,
            to_lower=to_lower,
            remove_punctuation=remove_punctuation,
            use_alternative_spelling=use_alternative_spelling
        )

    # A separate character-level calculator (without alignment) is handed to
    # MultiResult regardless of the error type chosen above.
    calculator_local = UKKLevenshteinDistanceCalculator(
        tokenizer=CharacterTokenizer(),
        get_alignment_result=False
    )
    result = MultiResult(output_results, calculator_local)

    if output_format == 'TABLE':
        click.echo(result.result())
        click.echo(str(result.multi_alignment_result))

    elif output_format == 'JSON':
        # Honour --file_path here as well, like the HTML report and the plots.
        json_target = _output_path(file_path, 'json_output.json')
        with open(json_target, 'w', encoding='utf-8') as f:
            json.dump(result.to_json(), f, ensure_ascii=False, indent=4)

    elif output_format == 'HTML':
        html_target = _output_path(file_path, "transcription-compare.html")
        # Context manager closes the handle even if the write fails; UTF-8
        # keeps non-ASCII transcripts writable on any platform.
        with open(html_target, 'w', encoding='utf-8') as f:
            f.write(result.to_html())

    if to_save_plot:
        sub_plot_name = list(output_results)
        alignment_results = [output_results[name].alignment_result for name in sub_plot_name]
        plot_alignment_result_only_distance(alignment_results, to_edit_width, to_edit_step, sub_plot_name, file_path)


# Script entry point: run the click command only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    main()

