import os
import os.path
import logging
import random
import subprocess
import shlex
import gzip
import re
import functools
import time
import imp
import sys
import json

# workaround needed to fix bug with SCons and the pickle module:
# drop the already-registered 'pickle' entry from sys.modules, load the
# module afresh through imp and register that copy, so that the plain
# `import pickle` below binds to the freshly loaded module.
del sys.modules['pickle']
sys.modules['pickle'] = imp.load_module('pickle', *imp.find_module('pickle'))
import pickle
import steamroller.scons


##########
# Preamble

# Bootstrap variable and environment objects: these exist only so that the
# CONFIG_FILE variable (default "steamroller_config.json") can be resolved
# before the real environment is built.
initial_vars = Variables()
initial_vars.AddVariables(
    ("CONFIG_FILE", "Configuration file", "steamroller_config.json"),
)

initial_env = Environment(variables=initial_vars, ENV=os.environ, TARFLAGS="-c -z", TARSUFFIX=".tgz",
                          tools=["default"],
)

# read the JSON-formatted SteamRoller config file
with open(initial_env["CONFIG_FILE"]) as ifd:
    config = json.load(ifd)

# Build variables: two display settings, the MODELS/TASKS/FIGURES sections of
# the config, and one variable per entry in the config's DEFAULTS section.
# items() is used instead of the Python-2-only iteritems() so the script
# behaves the same on Python 2 and also runs on Python 3.
config_vars = [
    ("OUTPUT_WIDTH", "Upper limit on how long a debugging line will be before it's truncated", 1000),
    ("VERBOSE", "Whether to print the full commands being executed", False),
    ("MODELS", "", config["MODELS"]),
    ("TASKS", "", config["TASKS"]),
    ("FIGURES", "", config["FIGURES"]),
] + [(k, "", v) for k, v in config["DEFAULTS"].items()]

# actual variable and environment objects; the steamroller tool is loaded
# here in addition to the SCons defaults
vars = Variables()
vars.AddVariables(
    *config_vars
)

env = Environment(variables=vars, ENV=os.environ, TARFLAGS="-c -z", TARSUFFIX=".tgz",
                  tools=["default", steamroller.scons.generate],
)

# function for width-aware printing of commands
def print_cmd_line(s, target, source, env):
    """Print the command string `s`, shortening it when it is too long.

    The limit is env["OUTPUT_WIDTH"] (converted with int()).  A string longer
    than the limit is printed as its first (limit - 10) characters, a literal
    "...", and its last 7 characters; anything else is printed unchanged.
    `target` and `source` are accepted but not used.
    """
    limit = int(env["OUTPUT_WIDTH"])
    if len(s) <= limit:
        print(s)
        return
    head = s[:limit - 10]
    tail = s[-7:]
    print(head + "..." + tail)

# configure logging: INFO level, each message prefixed with a
# "YYYY-MM-DD HH:MM:SS" timestamp
logging.basicConfig(format='%(asctime)s %(message)s', datefmt='%Y-%m-%d %H:%M:%S', level=logging.INFO)

# install the width-aware command-printing function defined above
env['PRINT_CMD_LINE_FUNC'] = print_cmd_line

# and how we decide if a dependency is out of date (SCons "timestamp-newer" decider)
env.Decider("timestamp-newer")


#############
# Experiments

def _train_and_apply(env, task, model, model_target, classified_target,
                     train_sources, test_sources, **overrides):
    """Register the train and apply steps of one model for one task.

    The "Train<NAME>" and "Apply<NAME>" builders are looked up in
    env["BUILDERS"] using the model's NAME.  The train builder is called on
    `train_sources`; the apply builder is called on the resulting model file
    followed by `test_sources`.  Both builders receive TASK_NAME, MODEL_NAME,
    GRID_RESOURCES and TAG_TYPE, plus any extra keyword `overrides`
    (e.g. FOLD, TRAINING_SIZE) needed by the target templates.

    Returns the tuple (model_file, train_resources, classified,
    apply_resources), i.e. the two nodes produced by each builder.
    """
    model_name = model["NAME"]
    # Both stages take the model's GRID_RESOURCES, falling back to the
    # environment-wide value.  The apply stage used to read a lowercase
    # "grid_resources" key that nothing else in this file uses, so it did
    # not pick up the settings given to the train stage.
    overrides.update(
        TASK_NAME=task["NAME"],
        MODEL_NAME=model_name,
        GRID_RESOURCES=model.get("GRID_RESOURCES", env.get("GRID_RESOURCES", [])),
        TAG_TYPE=task.get("TAG_TYPE", "attribute"),
    )
    # resolve both builders up front so a missing one fails before any
    # build step is registered
    train_builder = env["BUILDERS"]["Train%s" % model_name]
    apply_builder = env["BUILDERS"]["Apply%s" % model_name]
    model_file, train_resources = train_builder(env, model_target, train_sources, **overrides)
    classified, apply_resources = apply_builder(env, classified_target,
                                                [model_file] + list(test_sources), **overrides)
    return model_file, train_resources, classified, apply_resources


# models flagged DISABLED in the configuration are skipped for every task
enabled_models = [m for m in env["MODELS"] if not m.get("DISABLED", False)]

# rebound, per task, to the list of metric files combined below
metrics = {}
for task in env["TASKS"]:

    if task.get("DISABLED", False):
        continue

    classified_items = []
    train_resource_list = []
    apply_resource_list = []
    model_list = []
    task_name = task["NAME"]

    # data is either a single FILE to be randomly split into train/test, or
    # is already split into a TRAIN_FILE and a TEST_FILE
    if "FILE" in task:
        input_file = env.File(task["FILE"])
        count_file, _ = env.GetCount("work/${TASK_NAME}_total.txt.gz", input_file, TASK_NAME=task_name)
        fold_count = task.get("FOLD_COUNT", env["FOLD_COUNT"])

        for training_size in task.get("TRAINING_SIZES", env["TRAINING_SIZES"]):
            for fold in range(1, fold_count + 1):

                train, test, _ = env.CreateSplit(["work/${TASK_NAME}_train_${FOLD}_${TRAINING_SIZE}_${TESTING_SIZE}.txt.gz",
                                                  "work/${TASK_NAME}_test_${FOLD}_${TRAINING_SIZE}_${TESTING_SIZE}.txt.gz"],
                                                 count_file, FOLD=fold, TRAINING_SIZE=training_size, TASK_NAME=task_name)

                for model in enabled_models:
                    model_file, train_res, classified, apply_res = _train_and_apply(
                        env, task, model,
                        "work/${TASK_NAME}_${MODEL_NAME}_${TRAINING_SIZE}_${FOLD}.model.gz",
                        "work/${TASK_NAME}_${MODEL_NAME}_${TRAINING_SIZE}_${FOLD}_probabilities.txt.gz",
                        [train, input_file],
                        [test, input_file],
                        FOLD=fold, TRAINING_SIZE=training_size,
                    )
                    train_resource_list.append(train_res)
                    model_list.append(model_file)
                    apply_resource_list.append(apply_res)
                    classified_items.append(classified)

    else:
        train_data = env.File(task["TRAIN_FILE"])
        test_data = env.File(task["TEST_FILE"])
        train_indices, _ = env.NoSplit("work/${TASK_NAME}_train.txt.gz", train_data, TASK_NAME=task_name)
        test_indices, _ = env.NoSplit("work/${TASK_NAME}_test.txt.gz", test_data, TASK_NAME=task_name)

        for model in enabled_models:
            # a pre-split task has exactly one experiment per model, hence
            # the fixed "_1_1" in the target names
            model_file, train_res, classified, apply_res = _train_and_apply(
                env, task, model,
                "work/${TASK_NAME}_${MODEL_NAME}_1_1.model.gz",
                "work/${TASK_NAME}_${MODEL_NAME}_1_1_probabilities.txt.gz",
                [train_indices, train_data],
                [test_indices, test_data],
            )
            train_resource_list.append(train_res)
            model_list.append(model_file)
            apply_resource_list.append(apply_res)
            classified_items.append(classified)

    # collate metrics and draw figures only if at least one model was run
    if classified_items:
        fscores, _ = env.FScore("work/${TASK_NAME}_fscores.txt.gz", classified_items, TASK_NAME=task_name)
        accuracies, _ = env.Accuracy("work/${TASK_NAME}_accuracies.txt.gz", classified_items, TASK_NAME=task_name)
        train_resources, _ = env.CollateResources("work/${TASK_NAME}_trainresources.txt.gz", train_resource_list, TASK_NAME=task_name, STAGE="train")
        apply_resources, _ = env.CollateResources("work/${TASK_NAME}_applyresources.txt.gz", apply_resource_list, TASK_NAME=task_name, STAGE="apply")
        model_sizes, _ = env.ModelSizes("work/%s_modelsizes.txt.gz" % (task_name), model_list)
        metrics = [accuracies, fscores, train_resources, apply_resources, model_sizes]
        vals = env.CombineCSVs("work/${TASK_NAME}_metrics.txt.gz", metrics, TASK_NAME=task_name)
        # one plot per configured figure, all drawn from the combined metrics
        for figure in env["FIGURES"]:
            env.Plot("work/${TASK_NAME}_${PLOT_NAME}.png", vals,
                     TASK_NAME=task_name,
                     PLOT_NAME=figure["NAME"],
                     TITLE=task_name,
                     TYPE=figure["TYPE"],
                     X=figure["X"],
                     Y=figure["Y"],
                     XLABEL=figure["XLABEL"],
                     YLABEL=figure["YLABEL"],
                     COLOR=figure["COLOR"],
                     COLOR_LABEL=figure["COLOR_LABEL"]
            )

