import os
import os.path
import logging
import random
import subprocess
import shlex
import gzip
import re
import functools
import time
import imp
import sys
import json

# workaround needed to fix bug with SCons and the pickle module:
# discard whatever module object is currently registered under 'pickle' and
# load the module afresh through the imp API, so that the `import pickle`
# statement below binds to this freshly loaded module.
# NOTE(review): `del` raises KeyError if 'pickle' is not already in
# sys.modules; presumably SCons always imports it before this file runs.
# NOTE(review): the `imp` module was removed in Python 3.12; an importlib-based
# equivalent will be needed there. imp.find_module also returns an open file
# object that is never closed here.
del sys.modules['pickle']
sys.modules['pickle'] = imp.load_module('pickle', *imp.find_module('pickle'))
import pickle
import steamroller.scons

##########
# Preamble

# Bootstrap objects. This throwaway environment exists only to resolve
# CONFIG_FILE (presumably overridable on the SCons command line), so the
# configuration file can be read before the real environment is constructed.
initial_vars = Variables()
initial_vars.Add("CONFIG_FILE", "Configuration file", "steamroller_config.json")

initial_env = Environment(
    variables=initial_vars,
    ENV=os.environ,
    TARFLAGS="-c -z",
    TARSUFFIX=".tgz",
    tools=["default"],
)

# Read the JSON-formatted SteamRoller config file. Lines whose first non-blank
# characters are "//" are treated as comments and dropped before parsing.
# JSON text is UTF-8, so don't depend on the platform's locale encoding.
with open(initial_env["CONFIG_FILE"], encoding="utf-8") as ifd:
    config = json.loads("\n".join([l.strip() for l in ifd if not re.match(r"^\s*\/\/.*$", l)]))

# the DEFAULTS section is optional
config.setdefault("DEFAULTS", {})

# Set the logging level. LOGGING may be a level name (matched case-insensitively,
# e.g. "info") or a numeric level. basicConfig accepts either form directly.
# logging.getLevelName is not a reliable name->level lookup: unknown or lowercase
# names come back as "Level <name>", which basicConfig then rejects.
_log_level = config["DEFAULTS"].get("LOGGING", "WARNING")
if isinstance(_log_level, str):
    _log_level = _log_level.upper()
logging.basicConfig(format='%(asctime)s %(message)s', datefmt='%Y-%m-%d %H:%M:%S', level=_log_level)

# Construction variables for the real environment, in this order:
#   1. a few settings with built-in defaults,
#   2. one variable per top-level config section,
#   3. one variable per entry in the config's DEFAULTS section.
config_vars = [
    ("OUTPUT_WIDTH", "Upper limit on how long a debugging line will be before it's truncated", 1000),
    ("VERBOSE", "Whether to print the full commands being executed", False),
    ("WORK_PATH", "Sub-directory where files will be generated", "work"),
]
# A section missing from the config file is treated as empty rather than
# aborting the build with a bare KeyError.
config_vars += [(section, "", config.get(section, {})) for section in ["DATA_SETS",
                                                                        "FEATURE_EXTRACTORS",
                                                                        "FEATURE_TRANSFORMERS",
                                                                        "MODEL_TYPES",
                                                                        "MEASUREMENTS",
                                                                        "VISUALIZATIONS",
                                                                        "EXPERIMENTS"]]
config_vars += [(k, "", v) for k, v in config["DEFAULTS"].items()]


# actual variable and environment objects
# NOTE(review): `vars` shadows the Python builtin of the same name in this file.
vars = Variables()
vars.AddVariables(*config_vars)

# steamroller.scons.generate is loaded as an extra tool on top of SCons' defaults
env = Environment(variables=vars, ENV=os.environ, TARFLAGS="-c -z", TARSUFFIX=".tgz",
                  tools=["default", steamroller.scons.generate],
)

# function for width-aware printing of commands
def print_cmd_line(s, target, source, env):
    """Print command line ``s``, eliding its middle if it is longer than OUTPUT_WIDTH.

    The signature matches what is installed as PRINT_CMD_LINE_FUNC; ``target`` and
    ``source`` are unused.

    A truncated line keeps ``width // 2 - 2`` leading and ``width // 2 - 1``
    trailing characters around a ``...`` marker, so for widths of at least 4 the
    printed line is never longer than ``width``. For very small widths both parts
    are clamped at zero instead of producing negative slice bounds.
    """
    width = int(env["OUTPUT_WIDTH"])
    if len(s) <= width:
        print(s)
        return
    half = width // 2
    head = s[:max(half - 2, 0)]
    # s[-0:] would be the whole string, so a zero-length tail is handled explicitly
    tail_len = max(half - 1, 0)
    tail = s[len(s) - tail_len:] if tail_len else ""
    print(head + "..." + tail)

# and the command-printing function: install print_cmd_line (defined above) as
# the environment's PRINT_CMD_LINE_FUNC, so long command lines are shown with
# their middle elided to OUTPUT_WIDTH
env['PRINT_CMD_LINE_FUNC'] = print_cmd_line

# and how we decide if a dependency is out of date: the "timestamp-newer"
# decider (per the SCons manual) rebuilds a target when a dependency's
# timestamp is newer than the target's, instead of comparing content signatures
env.Decider("timestamp-newer")

#############
# Experiments
metrics = {}

# fallback grid resources for any model or visualization that doesn't specify its own
default_grid_resources = env.get("GRID_RESOURCES", [])

for experiment_name, experiment in env["EXPERIMENTS"].items():
    if experiment.get("DISABLED", False):
        continue

    # Measurement files from every (data set, fold, size, extractor, model)
    # combination in this experiment. They become the sources of the
    # experiment's visualizations.
    measurements = []
    for data_set_name, data_set in experiment.get("DATA_SETS", {}).items():

        data_set_fields = env["DATA_SETS"][data_set_name]

        # The Split step below needs a single FILE. Check up front so no builders
        # are created for a data set that cannot be processed.
        if "FILE" not in data_set_fields:
            raise Exception("The definition of data set {} must specify a FILE (separate TRAIN, DEV, and TEST files are not currently supported)".format(data_set_name))

        # Maps extractor name -> (feature file, that extractor's ARGUMENTS).
        # The arguments are kept with the file so each model below is built with
        # the settings of the extractor it is paired with, not whichever
        # extractor happened to be processed last.
        features = {}
        for extractor_name, extraction_spec in experiment.get("FEATURE_EXTRACTORS", {}).items():
            extractor_args = extraction_spec["ARGUMENTS"]
            extractor = env["BUILDERS"][extraction_spec["METHOD"]]
            feature_file = extractor(env, source=data_set_fields["FILE"], DATA_SET_NAME=data_set_name, FEATURE_EXTRACTOR_NAME=extractor_name, **extractor_args, **data_set)
            features[extractor_name] = (feature_file, extractor_args)

        # for feature_transformers ...

        # if folds aren't specified, default to one fold
        for fold in range(1, experiment.get("FOLDS", 1) + 1):

            # if down-sampling isn't specified, default to none
            for sizes in experiment.get("DOWN_SAMPLE", [None]):

                # NOTE(review): FOLD is not passed to Split, so every fold requests an
                # identical split; confirm the Split builder yields distinct per-fold outputs.
                train, dev, test = env.Split(data_set_fields["FILE"], NAME=data_set_name, SIZES=sizes)

                for extractor_name, (feature_file, extractor_args) in features.items():

                    for model_name, model in experiment["MODEL_TYPES"].items():

                        method = model["METHOD"]
                        method_args = model["ARGUMENTS"]
                        train_builder = env["BUILDERS"]["Train %s" % method]
                        apply_builder = env["BUILDERS"]["Apply %s" % method]
                        # a model may override the default grid resources for all of its steps
                        grid_resources = model.get("GRID_RESOURCES", default_grid_resources)

                        model_file = train_builder(env,
                                                   source=[train, feature_file],
                                                   FOLD=fold, SIZE=sizes, DATA_SET_NAME=data_set_name, FEATURE_EXTRACTOR_NAME=extractor_name, MODEL_NAME=model_name,
                                                   GRID_RESOURCES=grid_resources,
                                                   **extractor_args, **data_set, **method_args
                        )

                        probability_file = apply_builder(env,
                                                         source=[model_file, test, feature_file],
                                                         FOLD=fold, SIZE=sizes, DATA_SET_NAME=data_set_name, FEATURE_EXTRACTOR_NAME=extractor_name, MODEL_NAME=model_name,
                                                         GRID_RESOURCES=grid_resources,
                                                         **extractor_args, **data_set, **method_args
                        )

                        for measurement_name, measurement in experiment.get("MEASUREMENTS", {}).items():
                            caliper = env["BUILDERS"][measurement_name]
                            measurement_file = caliper(env,
                                                       source=probability_file,
                                                       FOLD=fold, SIZE=sizes, DATA_SET_NAME=data_set_name, FEATURE_EXTRACTOR_NAME=extractor_name, MODEL_NAME=model_name,
                                                       MEASUREMENT_NAME=measurement_name,
                                                       GRID_RESOURCES=grid_resources,
                                                       **extractor_args, **data_set, **method_args
                            )
                            measurements.append(measurement_file)

    for visualization_name, visualization in experiment.get("VISUALIZATIONS", {}).items():
        artist = env["BUILDERS"][visualization_name]
        # A visualization may carry its own GRID_RESOURCES; otherwise use the
        # global default. It must not borrow the settings of whichever model
        # loop variable happened to be bound last.
        if isinstance(visualization, dict):
            visualization_resources = visualization.get("GRID_RESOURCES", default_grid_resources)
        else:
            visualization_resources = default_grid_resources
        visualization_file = artist(env,
                                    source=measurements,
                                    EXPERIMENT_NAME=experiment_name, VISUALIZATION_NAME=visualization_name,
                                    GRID_RESOURCES=visualization_resources,
        )
