"""prepare_data.

Prepare dfr-browser data files from MALLET outputs.

The original version of this script by Andrew Goldstone can be found at
https://github.com/agoldst/dfr-browser/blob/master/bin/prepare-data.

This version has been made compatible with Python 3 and given some
further commenting.
"""

import gzip
import json
import os
import re
import zipfile as zf
from collections import defaultdict
from typing import Dict, Iterable, List, Optional, Tuple

def check_files(dir: str) -> None:
    """Check that the dfr-browser data files in a directory exist and are well-formed.

    Each file is checked independently and the outcome is printed. Problems
    are reported rather than raised, so one bad file does not stop the
    remaining checks from running.

    Args:
        dir (str): The directory to check.

    Returns:
        None
    """
    files = os.listdir(dir)
    # Lines of the metadata file. Stays empty when the file is absent or
    # unreadable, in which case the doc-topics/metadata cross-check is skipped.
    meta = []

    print("Checking info.json...")
    if "info.json" not in files:
        print('`info.json` not found. Use "prepare-data info-stub" to create one.')
    else:
        try:
            with open(os.path.join(dir, "info.json")) as f:
                json.load(f)
            print("info.json: ok")
        except (OSError, ValueError) as e:
            # Python 3 exceptions have no `.message`; format the exception itself.
            print(f"Unable to parse info.json as JSON. json error:\n{e}")

    print("Checking meta.csv.zip...")
    if "meta.csv.zip" in files:
        try:
            with zf.ZipFile(os.path.join(dir, "meta.csv.zip")) as z:
                # Only the first archive member is read.
                with z.open(z.infolist()[0]) as f:
                    meta = f.readlines()
            print("meta.csv.zip: ok")
        except (OSError, zf.BadZipFile, IndexError) as e:
            print(f"Problem with meta.csv.zip: {e}")
    else:
        print('No meta.csv.zip found. Use "prepare-data convert-citations" on DfR citations.tsv files.')

    print("Checking topic_scaled.csv...")
    if "topic_scaled.csv" not in files:
        print('No topic_scaled.csv found. This file is required only for the "scaled" overview.')
    else:
        print("topic_scaled.csv: ok")

    print("Checking tw.json...")
    if "tw.json" not in files:
        print('No tw.json found. Use "prepare-data convert-tw" on a topic-word matrix CSV or "prepare-data convert-keys" on a CSV listing top words and weights.')
    else:
        try:
            with open(os.path.join(dir, "tw.json")) as f:
                tw = json.load(f)
            if "alpha" not in tw:
                raise ValueError("alpha values not present")
            if "tw" not in tw:
                raise ValueError("tw field missing")
            # Topics are numbered from 1 in the messages.
            for number, topic in enumerate(tw["tw"], start=1):
                if "words" not in topic:
                    raise ValueError(f"Words missing for topic {number}")
                if "weights" not in topic:
                    raise ValueError(f"Weights missing for topic {number}")
            print("tw.json: ok")
        except (OSError, ValueError, TypeError) as e:
            print(f"Problem with tw.json: {e}")

    print("Checking dt.json.zip...")
    if "dt.json.zip" not in files:
        print('No dt.json.zip found. Use "prepare-data convert-dt" on a document-topic matrix CSV.')
    else:
        try:
            with zf.ZipFile(os.path.join(dir, "dt.json.zip")) as z:
                with z.open(z.infolist()[0]) as f:
                    dt = json.load(f)
            if "i" not in dt or "p" not in dt or "x" not in dt:
                raise ValueError("i, p, or x member missing")
            # The largest document index must point at the last metadata line.
            if len(meta) > 0 and max(dt["i"]) != len(meta) - 1:
                raise ValueError("doc-topics / metadata mismatch")
            print("dt.json.zip: ok")
        except (OSError, zf.BadZipFile, ValueError, TypeError, IndexError) as e:
            print(f"Problem with doc-topics JSON: {e}")


def info_stub(filepath: str) -> None:
    """Create a skeleton info.json for the user to fill in.

    The stub holds an empty title, an empty `<h2>` heading as the
    `meta_info` text, and a `VIS` section setting `overview_words` to 15.

    Args:
        filepath (str): Where the stub is written.

    Returns:
        None
    """
    stub = {
        "title": "",
        "meta_info": "<h2></h2>",
        "VIS": {"overview_words": 15},
    }
    with open(filepath, "w") as stub_file:
        json.dump(stub, stub_file, indent=4)
        print("Created stub file in " + stub_file.name)


def convert_citations(files: List[str], matchfile: Optional[str], output_file: str) -> None:
    """Convert JSTOR DfR citations.tsv files into zipped CSV metadata.

    For every non-blank data line, the first column (the document ID) and
    columns 3-9 are kept and written as one CSV row to a `meta.csv` member
    inside the output zip archive. The first line of each input file is
    skipped as a header.

    The matchfile is a file with one document ID per line. If supplied,
    these IDs are matched against the first column of the input metadata,
    and the output contains only matching documents (in the order given in
    the matchfile).

    Args:
        files (List[str]): A list of citations.tsv filepaths.
        matchfile (Optional[str]): The file of IDs to match, or None to keep
            every document in input order.
        output_file (str): The zip file to write to.

    Returns:
        None

    Raises:
        ValueError: If the matchfile names an ID that is not present in any
            of the input files.
    """
    rows = []
    row_index = {}  # document ID -> position of its row in `rows`
    for path in files:
        with open(path) as meta:
            meta.readline()  # header
            for line in meta:
                if not line.strip():
                    continue  # Ignore lines that contain only whitespace.
                raw = line.strip().split("\t")
                # Double embedded quotes so each field can be wrapped in quotes.
                fields = [s.replace('"', '""') for s in raw]
                # A later duplicate of an ID takes precedence over an earlier one.
                row_index[raw[0]] = len(rows)
                # Quote every field individually so that a comma inside a
                # field is not mistaken for a column separator.
                rows.append('"' + '","'.join(fields[0:1] + fields[2:9]) + '"')

    if matchfile is not None:
        with open(matchfile) as f:
            out_ids = [s.strip() for s in f if s.strip()]
        missing = [key for key in out_ids if key not in row_index]
        if missing:
            # Substituting some other document here would misalign the
            # metadata without any visible sign, so fail loudly instead.
            shown = ", ".join(missing[:10])
            raise ValueError(
                f"{len(missing)} ID(s) from {matchfile} not found in the citations: {shown}"
            )
        selected = [rows[row_index[key]] for key in out_ids]
    else:
        selected = rows

    with zf.ZipFile(output_file, "w") as z:
        z.writestr("meta.csv", "\n".join(selected))

    print(f"Wrote metadata to {output_file}")


def convert_tw(tw_file: str, output_file: str, vocab_file: str, param_file: Optional[str], n: int = 50) -> None:
    """Convert a topic-word matrix to a JSON file.

    Args:
        tw_file (str): The topic-word matrix file (a headerless CSV file
            with one row of integer weights per topic).
        output_file (str): The output file.
        vocab_file (str): The vocabulary file listing one line per column of tw_file.
        param_file (Optional[str]): A parameters file containing an
            `alpha = c(...)` entry (the `params.txt` file written by
            dfrtopics v0.2). If None, a warning is printed and every topic
            alpha is set to zero.
        n (int): The number of top words per topic (50 by default).

    Returns:
        None

    Raises:
        ValueError: If param_file is given but contains no
            `alpha = c(...)` entry.
    """
    with open(vocab_file) as f:
        vocab = [s.strip() for s in f]

    tw = []
    with open(tw_file) as f:
        for line in f:
            if not line.strip():
                continue  # Tolerate blank lines such as a trailing newline.
            weights = [int(x) for x in line.strip().split(",")]
            tw.append(transform_topic_weights(weights, vocab, n))

    if param_file is None:
        print("Warning: no parameters file, so alpha_k will be set to zero")
        alpha = [0] * len(tw)
    else:
        with open(param_file) as f:
            # Join the lines so a vector wrapped over several lines still matches.
            p = " ".join(l.strip() for l in f)
        # The character class also admits exponents such as 1e-04.
        m = re.search(r"alpha = c\(([0-9.eE+\-, ]+)\)", p)
        if m is None:
            raise ValueError(f"No 'alpha = c(...)' entry found in {param_file}")
        alpha = [float(a) for a in m.group(1).split(",") if a.strip()]

    write_tw(alpha, tw, output_file)


def transform_topic_weights(weights: List[int], vocab: List[str], n: int) -> Dict[str, list]:
    """Pick the top-weighted words of one topic in a JSON-compatible form.

    Args:
        weights (List[int]): The topic's weights; position `i` is the weight
            of `vocab[i]`.
        vocab (List[str]): The words, indexed in parallel with `weights`.
        n (int): The maximum number of top words to keep.

    Returns:
        Dict[str, list]: A dict with a "words" list and a parallel "weights"
            list, ordered by decreasing weight. Words with equal weight keep
            their original relative order.
    """
    # Rank positions, not values, so the same order can index both lists.
    # sorted() is stable, so ties stay in index order.
    ranked = sorted(range(len(weights)), key=lambda i: -weights[i])
    top = ranked[:n]
    return {
        "words": [vocab[w] for w in top],
        "weights": [weights[w] for w in top],
    }

def write_tw(alpha: list, tw: dict, output_file: str) -> None:
    """Save topic-word information as JSON and report where it went.

    The file holds a single object with two members: "alpha" and "tw".

    Args:
        alpha (list): The values stored under "alpha".
        tw (dict): The topic-word data stored under "tw".
        output_file (str): The path of the JSON file to write.

    Returns:
        None
    """
    payload = {"alpha": alpha, "tw": tw}
    with open(output_file, "w") as handle:
        json.dump(payload, handle)

    print(f"Wrote topic-words information to {handle.name}")

def convert_keys(keys_file: str, output_file: str) -> None:
    """Turn a CSV of top topic words into the topic-word JSON file.

    The number of comma-separated columns in the header line selects the
    parser: three columns means `topic,word,weight` (dfrtopics v0.2),
    four means `topic,alpha,word,weight` (dfrtopics v0.1).

    Args:
        keys_file (str): The keys CSV file, including its header line.
        output_file (str): The output file.

    Returns:
        None

    Raises:
        Exception: If the header has neither 3 nor 4 columns.
    """
    with open(keys_file) as f:
        n_columns = len(f.readline().strip().split(","))
        if n_columns not in (3, 4):
            raise Exception("Unknown top topic words format: expect 3 or 4 cols")
        # The header has been consumed; the parsers read the remaining lines.
        if n_columns == 3:
            print("New-style top topic words: alpha_k missing and left at 0")
            alphas, tw = keys_newstyle(f)
        else:
            alphas, tw = keys_oldstyle(f)

    write_tw(alphas, tw, output_file)

def keys_oldstyle(file: Iterable[str]) -> Tuple:
    """Convert the data lines of an old style keys file.

    Args:
        file (Iterable[str]): The lines after the header, each of the form
            `topic,alpha,word,weight`, with all lines of a topic adjacent.
            Blank lines are ignored.

    Returns:
        Tuple: The alpha values (one per topic) and the topic-word matrix
            (one `{"words": [...], "weights": [...]}` dict per topic), both
            in order of first appearance. Empty input gives two empty lists.
    """
    tw = []
    alphas = []
    words = []
    weights = []
    # None until the first data line, so no particular first topic id is assumed.
    current_topic = None
    current_alpha = None
    for line in file:
        if not line.strip():
            continue
        topic, alpha, word, weight = line.strip().split(",")
        topic = int(topic)
        if current_topic is not None and topic != current_topic:
            # Close the finished topic with its own alpha, not the alpha of
            # the line that starts the next topic.
            tw.append({"words": words, "weights": weights})
            alphas.append(current_alpha)
            words = []
            weights = []
        current_topic = topic
        current_alpha = float(alpha)
        words.append(word)
        weights.append(int(weight))

    if current_topic is not None:
        tw.append({"words": words, "weights": weights})
        alphas.append(current_alpha)

    return (alphas, tw)

def keys_newstyle(file: Iterable[str]) -> Tuple:
    """Convert the data lines of a new style keys file.

    Args:
        file (Iterable[str]): The lines after the header, each of the form
            `topic,word,weight`, with all lines of a topic adjacent.
            Blank lines are ignored.

    Returns:
        Tuple: The alpha values (all zero, since this format carries none)
            and the topic-word matrix (one
            `{"words": [...], "weights": [...]}` dict per topic, in order of
            first appearance). Empty input gives two empty lists.
    """
    tw = []
    words = []
    weights = []
    # None until the first data line, so no particular first topic id is assumed.
    current_topic = None
    for line in file:
        if not line.strip():
            continue
        topic, word, weight = line.strip().split(",")
        topic = int(topic)
        if current_topic is not None and topic != current_topic:
            tw.append({"words": words, "weights": weights})
            words = []
            weights = []
        current_topic = topic
        words.append(word)
        weights.append(int(weight))

    if current_topic is not None:
        tw.append({"words": words, "weights": weights})

    # alpha left at zero
    return ([0] * len(tw), tw)


def convert_dt(dt_file: str, output_file: str) -> None:
    """Write zipped document-topic information to a JSON file.

    The ordinary matrix is converted to the required sparse format.

    Args:
        dt_file (str): A headerless CSV with document-topic weights: one
            row per document, one integer column per topic. Blank lines
            are ignored.
        output_file (str): The output file.

    Returns:
        None

    Raises:
        ValueError: If the file has no data rows, a row has a different
            number of columns than the first row, or a cell is not an
            integer.
    """
    # One list per topic (column), each holding that topic's weight for
    # every document in file order.
    dt = None
    with open(dt_file) as f:
        for line_number, line in enumerate(f, start=1):
            if not line.strip():
                continue  # Tolerate blank lines such as a trailing newline.
            xs = [int(x) for x in line.strip().split(",")]
            if dt is None:
                dt = [[x] for x in xs]
            elif len(xs) != len(dt):
                raise ValueError(
                    f"{dt_file}, line {line_number}: expected {len(dt)} columns, found {len(xs)}"
                )
            else:
                for column, x in zip(dt, xs):
                    column.append(x)

    if dt is None:
        raise ValueError(f"{dt_file} contains no document-topic rows")

    write_dt(transform_dt(dt), output_file)

def transform_dt(dt: list) -> dict:
    """Convert a dense topic-by-document matrix into three sparse arrays.

    Args:
        dt (list): One list per topic, each holding that topic's weight for
            every document. The document count is taken from the first list.

    Returns:
        dict: `{"i": ..., "p": ..., "x": ...}` where "x" holds the nonzero
            weights topic by topic, "i" holds the document index of each of
            those weights, and "p" holds the running count of nonzero
            entries after each topic, starting with 0.
    """
    n_docs = len(dt[0])
    doc_indices = []
    values = []
    pointers = [0]
    for column in dt:
        nonzero = [(d, column[d]) for d in range(n_docs) if column[d] != 0]
        doc_indices.extend(d for d, _ in nonzero)
        values.extend(v for _, v in nonzero)
        pointers.append(len(values))

    return {"i": doc_indices, "p": pointers, "x": values}

def write_dt(dtj: dict, output_file: str) -> None:
    """Store the sparse document-topic data as `dt.json` inside a zip archive.

    Args:
        dtj (dict): The JSON-serializable document-topic data.
        output_file (str): The path of the zip archive to create.

    Returns:
        None
    """
    with zf.ZipFile(output_file, mode="w") as archive:
        serialized = json.dumps(dtj)
        archive.writestr("dt.json", serialized)

    print(f"Wrote sparse doc-topics to {output_file}")

def convert_state(state_file: str, tw_file: str = "tw.json", dt_file: str = "dt.json.zip", n: int = 50) -> None:
    """Use the MALLET sampling state to write both topic words and document topics.

    The state file is read as three header lines (the second carries the
    alpha values, the third the beta value) followed by one line per token
    with the whitespace-separated fields
    `doc source pos typeindex word topic`.

    Args:
        state_file (str): The gzipped state file from mallet train-topics.
        tw_file (str): The topic-words file to write.
        dt_file (str): The document-topic file to write.
        n (int): The number of top words to keep per topic. A numeric
            string is accepted and converted.

    Returns:
        None
    """
    # The command line hands option values over as strings; the top-word
    # slice needs an integer.
    n = int(n)

    with gzip.open(state_file, "rb") as f:
        f.readline()  # first header line is not used
        alpha = [float(a) for a in f.readline().decode().strip().split(" ")[2:]]
        beta = f.readline().decode().strip().split(" ")[2]
        print(f"beta value, not saved in a file: {beta}")

        # topic number -> {typeindex: number of tokens of that type in the topic}
        tw = defaultdict(lambda: defaultdict(int))
        # typeindex -> word
        vocab = {}
        # One {topic number: token count} dict per document, in file order
        dt = []
        # Counts for the document currently being read
        cur_dt = defaultdict(int)

        last_doc = 0    # assume we start at doc 0
        K = 0

        for line in f:
            fields = line.split()
            if not fields:
                continue  # Tolerate blank lines.
            doc, source, pos, typeindex, word, topic = fields
            doc = int(doc)
            typeindex = int(typeindex)
            topic = int(topic)
            # Track the highest topic number seen
            if topic > K:
                K = topic
            # A new document id closes the counts of the previous document
            if last_doc != doc:
                dt.append(cur_dt)
                cur_dt = defaultdict(int)
            cur_dt[topic] += 1
            tw[topic][typeindex] += 1
            if typeindex not in vocab:
                vocab[typeindex] = word.decode()
            last_doc = doc

        # final doc: after end of for loop
        if len(cur_dt) > 0:
            dt.append(cur_dt)

    # K is max(topic), but we want the number of topics. Taking the alpha
    # count into account keeps a topic that received no tokens from
    # shortening the output relative to alpha.
    K = max(K + 1, len(alpha))

    transformed_tw = []
    for t in range(K):
        # Look the topic up by its number (not by the order in which topics
        # first appear in the file) and hand over words that are
        # index-aligned with the weights, because transform_topic_weights
        # indexes both sequences by position.
        counts = tw.get(t, {})
        type_ids = list(counts)
        weights = [counts[type_id] for type_id in type_ids]
        words = [vocab[type_id] for type_id in type_ids]
        transformed_tw.append(transform_topic_weights(weights, words, n))
    write_tw(alpha, transformed_tw, tw_file)

    transformed_dt = transform_dt([[d.get(t, 0) for d in dt] for t in range(K)])
    write_dt(transformed_dt, dt_file)


def help():
    """Show the module docstring as the help text."""
    help_text = __doc__
    print(help_text)


def key_arg(args: list, key: str) -> Optional[str]:
    """Extract the value that follows a keyword flag in an argument list.

    When the flag is found together with a following value, both are
    removed from `args` in place, so that only the remaining arguments
    are left behind.

    Args:
        args (list): The arguments; modified in place on a match.
        key (str): The keyword flag, e.g. "-o".

    Returns:
        Optional[str]: The argument following the first occurrence of the
            flag, or None if the flag is absent or is the last argument
            (in which case `args` is left untouched).
    """
    try:
        i = args.index(key)
        res = args[i + 1]
    except (ValueError, IndexError):
        # ValueError: flag not present; IndexError: flag has no value after it.
        return None

    del args[i:i + 2]
    return res


if __name__ == "__main__":
    # Command-line entry point: the first argument names the command and the
    # remaining arguments are its options and input files.
    import sys

    if len(sys.argv) == 1:
        help()
    else:
        cmd = sys.argv[1]
        # Each key_arg() call removes the option it finds (flag and value)
        # from `rest`, so what is left afterwards are the positional arguments.
        rest = sys.argv[2:]
        out = key_arg(rest, "-o")
        if cmd == "help":
            help()
        elif cmd == "check":
            # Check the given directory, or the current one by default.
            check_files(rest[0] if len(rest) > 0 else ".")
        elif cmd == "info-stub":
            info_stub("info.json" if out is None else out)
        elif cmd == "convert-citations":
            if out is None:
                out = "meta.csv.zip"
            matchfile = key_arg(rest, "--ids")
            convert_citations(rest, matchfile, out)
        elif cmd == "convert-tw":
            if out is None:
                out = "tw.json"
            vocabf = key_arg(rest, "--vocab")
            if vocabf is None:
                raise Exception("A vocabulary file must be supplied with --vocab")
            paramf = key_arg(rest, "--param")
            n = key_arg(rest, "-n")
            if n is None:
                n = 50
            convert_tw(rest[0], out, vocabf, paramf, int(n))
        elif cmd == "convert-keys":
            if out is None:
                out = "tw.json"
            convert_keys(rest[0], out)
        elif cmd == "convert-dt":
            if out is None:
                out = "dt.json.zip"
            convert_dt(rest[0], out)
        elif cmd == "convert-state":
            tw = key_arg(rest, "--tw")
            if tw is None:
                tw = "tw.json"
            dt = key_arg(rest, "--dt")
            if dt is None:
                dt = "dt.json.zip"
            n = key_arg(rest, "-n")
            if n is None:
                n = 50
            # Option values arrive as strings; convert as convert-tw does.
            convert_state(rest[0], tw, dt, int(n))
        else:
            help()
