#!/usr/bin/env python

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

import os
import re
import sys
import itertools
import pprint
import argparse

from textworld.text_utils import extract_vocab_from_gamefiles
from textworld.generator import Game
from textworld.generator.text_grammar import Grammar, GrammarOptions


def build_parser():
    """Build the command-line parser of the extraction script.

    The general options (-f/--force, --merge, -q/--quiet, -v/--verbose) are
    declared once on a help-less parent parser that is shared by the
    top-level parser and by every subcommand parser.

    Returns:
        argparse.ArgumentParser: parser exposing the `vocab`, `entities`,
        `walkthroughs` and `commands` subcommands. The selected subcommand is
        stored in `args.subcommand` (None when none was given).
    """
    DESCRIPTION = "Extract information from a list of TextWorld games."
    GAMES_HELP = "List of TextWorld games (.ulx|.z8|.json)."

    # `add_help=False` avoids a duplicated -h/--help when used as a parent.
    general_parser = argparse.ArgumentParser(add_help=False)

    general_group = general_parser.add_argument_group('General settings')
    general_group.add_argument("-f", "--force", action="store_true",
                               help="Overwrite the output file if it already exists.")
    general_group.add_argument("--merge", action="store_true",
                               help="Merge extracted information to existing output file.")

    # --quiet and --verbose cannot be combined.
    verbosity_group = general_group.add_mutually_exclusive_group()
    verbosity_group.add_argument("-q", "--quiet", action="store_true",
                                 help="Do not print warnings (e.g. missing walkthroughs).")
    verbosity_group.add_argument("-v", "--verbose", action="store_true",
                                 help="Print the extracted information.")

    parser = argparse.ArgumentParser(parents=[general_parser], description=DESCRIPTION)
    subparsers = parser.add_subparsers(dest="subcommand",
                                       help='Type of information to extract.')

    # `vocab` may be called without any game, e.g. when only --theme is used.
    vocab_parser = subparsers.add_parser("vocab", parents=[general_parser],
                                         help='Extract vocabulary.')
    vocab_parser.add_argument("games", metavar="game", nargs="*",
                              help=GAMES_HELP)
    vocab_parser.add_argument("--output", default="vocab.txt",
                              help="Output file containing all words (.txt). Default: %(default)s")
    vocab_parser.add_argument("--theme",
                              help="Provide a text grammar theme from which to extract words.")

    entities_parser = subparsers.add_parser("entities", parents=[general_parser],
                                            help='Extract entity names.')
    entities_parser.add_argument("games", metavar="game", nargs="+",
                                 help=GAMES_HELP)
    entities_parser.add_argument("--output", default="entities.txt",
                                 help="Output file containing all entity names (.txt). Default: %(default)s")

    walkthrough_parser = subparsers.add_parser("walkthroughs", parents=[general_parser],
                                               help='Extract walkthroughs.')
    walkthrough_parser.add_argument("games", metavar="game", nargs="+",
                                    help=GAMES_HELP)
    walkthrough_parser.add_argument("--output", default="walkthroughs.txt",
                                    help="Output file containing all walkthroughs (.txt). Default: %(default)s")

    commands_parser = subparsers.add_parser("commands", parents=[general_parser],
                                            help='Extract all possible commands.')
    commands_parser.add_argument("games", metavar="game", nargs="+",
                                 help=GAMES_HELP)
    commands_parser.add_argument("--output", default="commands.txt",
                                 help="Output file containing all commands (.txt). Default: %(default)s")

    return parser


def _load_game(gamefile):
    """Load the game stored in the .json file sitting next to `gamefile`.

    The extension of `gamefile` is replaced by ".json" and the resulting
    path is handed to `Game.load`.
    """
    return Game.load(os.path.splitext(gamefile)[0] + ".json")


def _pick_unit(units, count):
    """Return the singular (`units[0]`) or plural (`units[1]`) noun for `count`."""
    return units[0] if count == 1 else units[1]


def _extract_entities(gamefiles):
    """Return the sorted, de-duplicated object names of all `gamefiles`."""
    names = set()
    for gamefile in gamefiles:
        names.update(_load_game(gamefile).objects_names)

    return sorted(names)


def _extract_vocab(gamefiles, theme=None):
    """Return the sorted vocabulary of `gamefiles`, plus the words of `theme` if given."""
    words = extract_vocab_from_gamefiles(gamefiles)

    if theme:
        # Extract words from text grammar theme.
        grammar = Grammar(GrammarOptions(theme=theme))
        text = " ".join(set(grammar.get_vocabulary())).lower()
        # Keep only lowercase letters, digits, spaces and the characters - _ '
        # ; everything else becomes a word separator.
        text = re.sub(r"[^a-z0-9\-_ ']", " ", text)
        stripped = (word.strip("-'_") for word in text.split())
        # A token made only of - _ ' strips down to "", which is not a word.
        words |= {word for word in stripped if word}

    return sorted(words)


def _extract_walkthroughs(gamefiles, quiet=False):
    """Return one " > "-joined walkthrough per game that has one.

    Games without a walkthrough are skipped; unless `quiet` is set, a message
    is printed for each of them.
    """
    walkthroughs = []
    for gamefile in gamefiles:
        game = _load_game(gamefile)
        if game.walkthrough:
            walkthroughs.append(" > ".join(game.walkthrough))
        elif not quiet:
            print("Walkthrough is missing for '{}'.".format(gamefile))

    return walkthroughs


def _extract_commands(gamefiles):
    """Return the sorted set of commands obtained by filling every command template.

    For each game, every `{type}` placeholder of every command template is
    replaced by each object name registered under that type, and all
    combinations are generated.
    """
    all_commands = set()
    for gamefile in gamefiles:
        game = _load_game(gamefile)

        # Build objects_per_types mapping. An object is listed under its own
        # type and under every ancestor of that type.
        objects_per_type = {t: [] for t in game.objects_types}
        for name, t in game.objects_names_and_types:
            objects_per_type[t].append(name)
            for ancestor in game.kb.types.get_ancestors(t):
                objects_per_type[ancestor].append(name)

        # Matches "{<type>}" for any known type; type names are escaped so
        # they are always matched literally.
        alternatives = "|".join(re.escape(t) for t in game.objects_types)
        regex = re.compile(r"{{({})}}".format(alternatives))

        for template in game.command_templates:
            placeholders = regex.findall(template)
            candidates_per_slot = [objects_per_type[t] for t in placeholders]

            # A template without placeholder yields exactly one (empty)
            # combination, i.e. the template itself.
            for candidates in itertools.product(*candidates_per_slot):
                command = template
                for placeholder, candidate in zip(placeholders, candidates):
                    # Replace only the first occurrence so repeated
                    # placeholders are filled left to right.
                    command = command.replace("{" + placeholder + "}", candidate, 1)

                all_commands.add(command)

    return sorted(all_commands)


def main():
    """Entry point: extract the requested information and write it to the output file.

    Parses the command line, runs the extraction matching the chosen
    subcommand, reports how many items were found and writes them, one per
    line, to `args.output`. With --merge, the items already present in the
    output file are kept. Exits with status 1 when the output file already
    exists and neither --force nor --merge was given.
    """
    parser = build_parser()
    args = parser.parse_args()

    if not args.subcommand:
        parser.error("A subcommand is required.")

    if os.path.isfile(args.output) and not (args.force or args.merge):
        msg = "{} already exists. Either use --merge or --force to overwrite."
        print(msg.format(args.output))
        sys.exit(1)

    if args.subcommand == "entities":
        units = ["entity", "entities"]
        infos = _extract_entities(args.games)

    elif args.subcommand == "vocab":
        units = ["word", "words"]
        infos = _extract_vocab(args.games, theme=args.theme)

    elif args.subcommand == "walkthroughs":
        units = ["walkthrough", "walkthroughs"]
        infos = _extract_walkthroughs(args.games, quiet=args.quiet)

    elif args.subcommand == "commands":
        units = ["command", "commands"]
        infos = _extract_commands(args.games)

    unit = _pick_unit(units, len(infos))
    if args.verbose:
        print("{} found:".format(unit.title()))
        pprint.pprint(infos)

    print("-> Found {} {}.".format(len(infos), unit))

    if args.merge and os.path.isfile(args.output):
        with open(args.output) as f:
            # Ignore blank lines: an empty or newline-terminated file must
            # not contribute an empty-string entry to the merged output.
            old_infos = {line for line in f.read().split("\n") if line}

        new_infos = set(infos) - old_infos
        if args.verbose and len(new_infos) > 0:
            print("\nNew {}:".format(_pick_unit(units, len(new_infos))))
            pprint.pprint(sorted(new_infos))

        infos = sorted(set(infos) | old_infos)
        unit = _pick_unit(units, len(infos))
        print("-> Saved {} {} ({} new).".format(len(infos), unit, len(new_infos)))

    with open(args.output, "w") as f:
        f.write("\n".join(infos))

# Run the extraction only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
