import os
import os.path
import logging
import random
import subprocess
import shlex
import gzip
import re
import functools
import time
import imp
import sys
import json
del sys.modules['pickle']

sys.modules['pickle'] = imp.load_module('pickle', *imp.find_module('pickle'))

import pickle
import steamroller.scons


# Bootstrap variables: the only setting declared at this point is the path of
# the JSON configuration file (overridable, defaulting to the name below).
vars = Variables()
vars.AddVariables(("CONFIG_FILE", "Configuration file", "steamroller_config.json"))

def print_cmd_line(s, target, source, env):
    """Print a command line, shortening it when it exceeds OUTPUT_WIDTH.

    Registered further down in this file as the environment's
    PRINT_CMD_LINE_FUNC.

    Args:
        s: The command-line string to display.
        target: Unused; present to match the callback signature.
        source: Unused; present to match the callback signature.
        env: Mapping that provides "OUTPUT_WIDTH" (anything int() accepts).
    """
    width = int(env["OUTPUT_WIDTH"])
    if len(s) <= width:
        print(s)
    elif width >= 10:
        # Head + "..." + last 7 characters adds up to exactly `width`.
        print(s[:width - 10] + "..." + s[-7:])
    else:
        # Too narrow for the head/ellipsis/tail layout: a negative slice
        # bound would yield output longer than the limit, so cut hard.
        print(s[:max(width, 0)])

# Bootstrap environment: built from the bootstrap variables so that
# CONFIG_FILE can be resolved before the full variable set is known.
initial_env = Environment(
    variables=vars,
    ENV=os.environ,
    TARFLAGS="-c -z",
    TARSUFFIX=".tgz",
    tools=["default"],
)

# Load the JSON configuration; it must provide the MODELS, TASKS, FIGURES and
# DEFAULTS keys read below.
with open(initial_env["CONFIG_FILE"]) as ifd:
    config = json.load(ifd)

# Fixed build variables, followed by one variable per entry in the
# configuration's DEFAULTS mapping (its value becomes the variable's default).
# items() yields the same pairs as iteritems() and also exists on Python 3.
config_vars = [
    ("OUTPUT_WIDTH", "Upper limit on how long a debugging line will be before it's truncated", 1000),
    ("VERBOSE", "Whether to print the full commands being executed", False),
    ("MODELS", "", config["MODELS"]),
    ("TASKS", "", config["TASKS"]),
    ("FIGURES", "", config["FIGURES"]),
] + [(k, "", v) for k, v in config["DEFAULTS"].items()]

vars = Variables()
vars.AddVariables(*config_vars)

logging.basicConfig(format='%(asctime)s %(message)s', datefmt='%Y-%m-%d %H:%M:%S', level=logging.INFO)

# Full build environment: same base settings as the bootstrap one, plus the
# steamroller tool and the configuration-derived variables.
env = Environment(
    variables=vars,
    ENV=os.environ,
    TARFLAGS="-c -z",
    TARSUFFIX=".tgz",
    tools=["default", steamroller.scons.generate],
)

# Route command echoing through the width-limited printer defined above.
env['PRINT_CMD_LINE_FUNC'] = print_cmd_line
env.Decider("timestamp-newer")

# Declare the build graph for every configured task: a count of the input,
# train/test splits per (training size, fold), one train + apply step per
# model on each split, then per-task metric collation and figures.
metrics = {}
for task in env["TASKS"]:
    classified_items = []
    train_resource_list = []
    apply_resource_list = []
    model_list = []
    task_name = task["NAME"]
    input_file = env.File(task["FILE"])
    count_file, _ = env.GetCount("work/${TASK_NAME}_total.txt.gz", input_file, TASK_NAME=task_name)

    # Task-level settings take precedence over the environment-wide ones.
    for training_size in task.get("TRAINING_SIZES", env["TRAINING_SIZES"]):
        for fold in range(1, task.get("FOLD_COUNT", env["FOLD_COUNT"]) + 1):

            # NOTE(review): ${TESTING_SIZE} is not passed here, so it is
            # presumably resolved from the environment (e.g. a DEFAULTS
            # entry) -- confirm.
            train, test, _ = env.CreateSplit(
                ["work/${TASK_NAME}_train_${FOLD}_${TRAINING_SIZE}_${TESTING_SIZE}.txt.gz",
                 "work/${TASK_NAME}_test_${FOLD}_${TRAINING_SIZE}_${TESTING_SIZE}.txt.gz"],
                count_file, FOLD=fold, TRAINING_SIZE=training_size, TASK_NAME=task_name)

            for model in env["MODELS"]:
                model_name = model["NAME"]
                train_builder = env["BUILDERS"]["Train%s" % model_name]
                apply_builder = env["BUILDERS"]["Apply%s" % model_name]

                # Use one lookup for both steps. The apply step previously
                # read a lower-case "grid_resources" key with no environment
                # fallback, unlike the training step.
                grid_resources = model.get("GRID_RESOURCES", env.get("GRID_RESOURCES", []))

                model_file, resources = train_builder(
                    env,
                    "work/${TASK_NAME}_${MODEL_NAME}_${TRAINING_SIZE}_${FOLD}.model.gz",
                    [train, input_file],
                    FOLD=fold, TRAINING_SIZE=training_size, TASK_NAME=task_name, MODEL_NAME=model_name,
                    GRID_RESOURCES=grid_resources,
                )
                train_resource_list.append(resources)
                model_list.append(model_file)

                classified, resources = apply_builder(
                    env,
                    "work/${TASK_NAME}_${MODEL_NAME}_${TRAINING_SIZE}_${FOLD}_probabilities.txt.gz",
                    [model_file, test, input_file],
                    FOLD=fold, TRAINING_SIZE=training_size, TASK_NAME=task_name, MODEL_NAME=model_name,
                    GRID_RESOURCES=grid_resources,
                )
                apply_resource_list.append(resources)
                classified_items.append(classified)

    # Only collate and plot when at least one model was applied for this task.
    if classified_items:
        fscores, _ = env.FScore("work/${TASK_NAME}_fscores.txt.gz", classified_items, TASK_NAME=task_name)
        accuracies, _ = env.Accuracy("work/${TASK_NAME}_accuracies.txt.gz", classified_items, TASK_NAME=task_name)
        train_resources, _ = env.CollateResources("work/${TASK_NAME}_trainresources.txt.gz", train_resource_list, TASK_NAME=task_name, STAGE="train")
        apply_resources, _ = env.CollateResources("work/${TASK_NAME}_applyresources.txt.gz", apply_resource_list, TASK_NAME=task_name, STAGE="apply")
        model_sizes, _ = env.ModelSizes("work/%s_modelsizes.txt.gz" % (task_name), model_list)
        # Rebinds the module-level name to this task's list of metric files.
        metrics = [accuracies, fscores, train_resources, apply_resources, model_sizes]
        vals = env.CombineCSVs("work/${TASK_NAME}_metrics.txt.gz", metrics, TASK_NAME=task_name)
        # One plot per configured figure, all drawn from the combined metrics.
        for figure in env["FIGURES"]:
            env.Plot(
                "work/${TASK_NAME}_${PLOT_NAME}.png", vals,
                TASK_NAME=task_name,
                PLOT_NAME=figure["NAME"],
                TITLE=task_name,
                TYPE=figure["TYPE"],
                X=figure["X"],
                Y=figure["Y"],
                XLABEL=figure["XLABEL"],
                YLABEL=figure["YLABEL"],
                COLOR=figure["COLOR"],
                COLOR_LABEL=figure["COLOR_LABEL"],
            )

