#!/usr/bin/env python
import argparse
from hagelslag.util.Config import Config
from hagelslag.processing.TrackModeler import TrackModeler
from multiprocessing import Pool, Manager
from hagelslag.processing.EnsembleProducts import *
from scipy.ndimage import gaussian_filter
import pandas as pd
import numpy as np
import traceback
from datetime import timedelta
import os


def main():
    """
    Command-line entry point for the hail forecasting workflow.

    Parses the config filename and the stage flags, builds the Config object
    with its list of required attributes, and then runs each requested stage
    in a fixed order: training, forecasting, gridding, ensemble probabilities.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("config", help="Config filename.")
    stage_flags = (
        ("-t", "--train", "Train machine learning models."),
        ("-f", "--fore", "Generate forecasts from machine learning models."),
        ("-e", "--ens", "Generate ensemble products from ML and raw ensemble variables."),
        ("-g", "--grid", "Generate forecast grids for machine learning models."),
    )
    for short_flag, long_flag, help_text in stage_flags:
        parser.add_argument(short_flag, long_flag, action="store_true", help=help_text)
    args = parser.parse_args()

    # Attribute names handed to Config as the set the config file must define.
    required = [
        "ensemble_name", "train_data_path", "forecast_data_path", "member_files",
        "data_format", "condition_model_names", "condition_model_objs", "condition_input_columns",
        "condition_output_column", "group_col",
        "model_path", "metadata_columns", "data_json_path", "forecast_json_path",
        "load_models", "ensemble_members", "ml_grid_method", "neighbor_radius", "neighbor_sigma",
        "ensemble_consensus_path", "ensemble_variables", "ensemble_variable_thresholds",
        "ensemble_data_path", "size_dis_training_path", "netcdf_path", "watershed_variable",
        "model_map_file",
    ]
    config = Config(args.config, required)

    # Optional settings fall back to defaults when the config file omits them.
    if not hasattr(config, "run_date_format"):
        config.run_date_format = "%Y%m%d-%H%M"

    if args.train or args.fore:
        if not hasattr(config, "sector_train"):
            config.sector_train = None

        # Training and forecasting share one TrackModeler instance.
        track_modeler = TrackModeler(config.ensemble_name,
                                     config.train_data_path,
                                     config.forecast_data_path,
                                     config.member_files,
                                     config.start_dates,
                                     config.end_dates,
                                     config.sector_train,
                                     config.model_map_file,
                                     config.group_col)
        if args.train:
            train_models(track_modeler, config)
            training_data_percentiles(members=config.ensemble_members,
                                      ensemble_name=config.ensemble_name,
                                      watershed_obj=config.watershed_variable,
                                      training_data_path=config.netcdf_path,
                                      start_date=config.start_dates,
                                      end_date=config.end_dates,
                                      csv_outpath=config.size_dis_training_path)
        if args.fore:
            forecasts = make_forecasts(track_modeler, config)
            output_forecasts_csv(forecasts, track_modeler, config)

    if args.grid:
        generate_ml_grids(config, mode="forecast")

    if args.ens:
        print("Making ensemble probs")
        make_ensemble_probabilities(config, mode="forecast")


def train_models(track_modeler, config):
    """
    Fit the condition and size-distribution models on the training data and save them.

    Args:
        track_modeler (hagelslag.TrackModeler): an initialized TrackModeler object
        config: Config object supplying the data format, the model names and objects,
            the input/output columns, the condition threshold and the model output path.
    """
    track_modeler.load_data(mode="train", format=config.data_format)

    # Condition models: fitted against the configured condition threshold.
    condition_args = (config.condition_model_names,
                      config.condition_model_objs,
                      config.condition_input_columns,
                      config.condition_output_column,
                      config.condition_threshold)
    track_modeler.fit_condition_threshold_models(*condition_args)

    # Size-distribution component models.
    size_dist_args = (config.size_distribution_model_names,
                      config.size_distribution_model_objs,
                      config.size_distribution_input_columns)
    track_modeler.fit_size_distribution_component_models(
        *size_dist_args, output_columns=config.size_distribution_output_columns)

    track_modeler.save_models(config.model_path)


def make_forecasts(track_modeler, config):
    """
    Run the condition and size-distribution models on the forecast data.

    Args:
        track_modeler (hagelslag.processing.TrackModeler object): TrackModeler object with configuration information
        config (hagelslag.util.Config object): Configuration information
    Returns:
        dict with the keys "condition" and "dist", holding the output of the
        condition models and of the size-distribution component models.
    """
    print("Load data")
    track_modeler.load_data(mode="forecast", format=config.data_format)

    # Saved models are only read back in when the config asks for it.
    if config.load_models:
        print("Load models")
        track_modeler.load_models(config.model_path)

    print("Condition forecasts")
    condition_forecasts = track_modeler.predict_condition_models(
        config.condition_model_names,
        config.condition_input_columns,
        config.metadata_columns)

    print("Size Distribution Forecasts")
    dist_forecasts = track_modeler.predict_size_distribution_component_models(
        config.size_distribution_model_names,
        config.size_distribution_input_columns,
        config.size_distribution_output_columns,
        config.metadata_columns,
        location=config.size_distribution_loc)

    return {"condition": condition_forecasts, "dist": dist_forecasts}


def output_forecasts(forecasts, track_modeler, config):
    """
    Hand the forecasts to TrackModeler.output_forecasts_json_parallel.

    Args:
        forecasts: dict
            dictionary containing forecast values organized by type
        track_modeler: hagelslag.processing.TrackModeler
            TrackModeler object
        config:
            Config object; supplies the model name lists, the input and output
            JSON paths and the number of processes.
    Returns:
        None
    """
    json_args = (forecasts,
                 config.condition_model_names,
                 config.size_model_names,
                 config.size_distribution_model_names,
                 config.track_model_names,
                 config.data_json_path,
                 config.forecast_json_path,
                 config.num_procs)
    track_modeler.output_forecasts_json_parallel(*json_args)


def output_forecasts_csv(forecasts, track_modeler, config):
    """
    Hand the forecasts to TrackModeler.output_forecasts_csv in "forecast" mode.

    Args:
        forecasts: dict
            dictionary containing forecast values organized by type
        track_modeler: hagelslag.processing.TrackModeler
            TrackModeler object
        config:
            Config object; supplies forecast_csv_path and run_date_format.
    Returns:
        None
    """
    csv_path = config.forecast_csv_path
    date_format = config.run_date_format
    track_modeler.output_forecasts_csv(forecasts, "forecast", csv_path,
                                       run_date_format=date_format)


def make_ensemble_probabilities(config, mode="forecast"):
    """
    Generate neighborhood probabilities from each ensemble member in parallel and output to netCDF.

    One pool task (merge_member_probabilities) collects member fields from a shared
    queue and writes the consensus files. Every other pool task
    (member_neighborhood_probability) produces the fields for one member,
    one model or variable, and one run date.

    Only run dates whose ensemble data is readable for every member are processed.
    The merge task is told about exactly those dates, because it waits until every
    date it was given has received output from all members.

    Args:
        config: Config object. Reads start_dates, end_dates, num_procs, ensemble_name,
            ensemble_members, single_step, ensemble_variable_thresholds, forecast_json_path,
            size_distribution_model_names, ensemble_variables, ensemble_data_path,
            start_hour, end_hour, neighbor_radius, neighbor_sigma, ensemble_consensus_path,
            model_map_file, neighbor_condition_model, condition_threshold and num_track_samples.
        mode: key into config.start_dates and config.end_dates.

    Returns:
        None
    """
    # pd.date_range replaces DatetimeIndex(start=..., end=..., freq=...),
    # a constructor form that pandas 1.0 removed.
    run_dates = pd.date_range(start=config.start_dates[mode],
                              end=config.end_dates[mode],
                              freq='1D').to_pydatetime()
    ensemble_name = config.ensemble_name
    members = config.ensemble_members
    single_step = config.single_step
    variable_thresholds = config.ensemble_variable_thresholds
    track_path = config.forecast_json_path
    all_model_names = config.size_distribution_model_names + ["WRF"] * len(config.ensemble_variables)
    all_var_names = ["dist"] * len(config.size_distribution_model_names) + config.ensemble_variables

    # The layout of the ensemble data differs for the NCAR ensemble.
    if ensemble_name.upper() == "NCAR":
        full_path = config.ensemble_data_path + "/{0}/{1}_surrogate_{0}.nc"
        date_format = "%Y%m%d%H"
    else:
        full_path = config.ensemble_data_path + "/{1}/{0}/"
        date_format = "%Y%m%d"

    # Decide up front which run dates can be processed.
    available_dates = []
    for run_date in run_dates:
        run_date_str = run_date.strftime(date_format)
        if all(os.access(full_path.format(run_date_str, member), os.R_OK) for member in members):
            available_dates.append(run_date)
        else:
            print(run_date.strftime("%Y%m%d") + " not available")
    if not available_dates:
        # Nothing to compute or merge, so no worker pool is needed.
        return

    # The merge task occupies one worker for the whole run, so at least one more
    # worker is needed for the member tasks to make progress.
    pool = Pool(max(config.num_procs, 2))
    manager = Manager()
    q = manager.Queue()
    pool.apply_async(merge_member_probabilities, (ensemble_name, all_model_names,
                                                  len(members), available_dates, all_var_names,
                                                  config.start_hour, config.end_hour,
                                                  variable_thresholds, config.neighbor_radius,
                                                  config.neighbor_sigma,
                                                  config.ensemble_consensus_path, q))
    for run_date in available_dates:
        print("Starting " + run_date.strftime("%Y%m%d"))
        start_date = run_date + timedelta(hours=config.start_hour)
        end_date = run_date + timedelta(hours=config.end_hour)
        # Machine learning size-distribution forecasts, read from the forecast JSON path.
        for model_name in config.size_distribution_model_names:
            for member in members:
                pool.apply_async(member_neighborhood_probability,
                                 (ensemble_name, model_name, member, run_date, "dist",
                                  start_date, end_date, track_path, single_step, variable_thresholds,
                                  config.neighbor_radius, q, config.model_map_file,
                                  config.neighbor_condition_model, config.condition_threshold,
                                  config.num_track_samples, config.neighbor_sigma))
        # Raw ensemble variables, read from the ensemble data path.
        for model_variable in config.ensemble_variables:
            for member in members:
                pool.apply_async(member_neighborhood_probability,
                                 args=(ensemble_name, "WRF", member, run_date, model_variable,
                                       start_date, end_date, config.ensemble_data_path, single_step,
                                       variable_thresholds, config.neighbor_radius, q,
                                       config.model_map_file,
                                       config.neighbor_condition_model, config.condition_threshold,
                                       config.num_track_samples, config.neighbor_sigma))
    pool.close()
    pool.join()


def generate_ml_grids(config, mode="forecast"):
    """
    Creates gridded machine learning model forecasts and writes them to GRIB2 files.

    Submits one generate_ml_member_grid task per run date and ensemble member to a
    process pool and waits for all of them to finish.

    Args:
        config: Config object. Reads start_dates, end_dates, num_procs,
            size_distribution_model_names, start_hour, end_hour, ensemble_members,
            ensemble_name, single_step, neighbor_condition_model, forecast_csv_path,
            netcdf_path, grib_path, model_map_file, size_dis_training_path and
            watershed_variable.
        mode: key into config.start_dates and config.end_dates.

    Returns:
        None
    """
    # pd.date_range replaces DatetimeIndex(start=..., end=..., freq=...),
    # a constructor form that pandas 1.0 removed. The dates are built before the
    # pool so that a bad date setting does not leave idle workers behind.
    run_dates = pd.date_range(start=config.start_dates[mode],
                              end=config.end_dates[mode],
                              freq='1D')
    ml_model_list = config.size_distribution_model_names
    print(ml_model_list)
    ml_var = "hail"
    pool = Pool(config.num_procs)
    for run_date in run_dates:
        start_date = run_date + timedelta(hours=config.start_hour)
        end_date = run_date + timedelta(hours=config.end_hour)
        for member in config.ensemble_members:
            args = (config.ensemble_name, ml_model_list, member, run_date, ml_var, start_date, end_date,
                    config.single_step, config.neighbor_condition_model, config.forecast_csv_path,
                    config.netcdf_path, config.grib_path, config.model_map_file,
                    config.size_dis_training_path, config.watershed_variable)
            pool.apply_async(generate_ml_member_grid, args)
    pool.close()
    pool.join()
    return


def generate_ml_member_grid(ensemble_name, model_names, member, run_date, variable, start_date, end_date,
                            single_step, neighbor_condition_model, forecast_csv_path, netcdf_path,
                            grib_path, map_file, size_distribution_training_path, watershed_obj):
    """
    Build gridded forecasts for one ensemble member and run date and write them as GRIB2.

    Nothing is done (apart from a message) when the member's forecast CSV file for the
    run date does not exist. Otherwise each model in model_names is processed in turn:
    forecast CSV and netCDF data are loaded, quantile matched, encoded as GRIB2 and
    written to a per-date directory under grib_path. A failure while matching, encoding
    or writing one model is reported and the remaining models are still attempted.

    Args:
        ensemble_name: name of the ensemble, used in the forecast CSV filename.
        model_names: list of machine learning model names to grid.
        member: ensemble member name.
        run_date: model run date; must provide strftime.
        variable: variable name passed to EnsembleMemberProduct.
        start_date: start of the forecast period.
        end_date: end of the forecast period.
        single_step: passed through to EnsembleMemberProduct.
        neighbor_condition_model: passed to EnsembleMemberProduct as condition_model_name.
        forecast_csv_path: directory prefix of the forecast CSV files; it is
            concatenated directly with the filename.
        netcdf_path: path handed to load_forecast_netcdf_data.
        grib_path: base directory for GRIB2 output; a %Y%m%d subdirectory is created in it.
        map_file: passed to EnsembleMemberProduct as map_file.
        size_distribution_training_path: passed through to EnsembleMemberProduct.
        watershed_obj: passed through to EnsembleMemberProduct.

    Returns:
        None
    """
    try:
        csv_file = forecast_csv_path + "hail_forecasts_{0}_{1}_{2}.csv".format(
            ensemble_name, member, run_date.strftime("%Y%m%d-%H%M"))
        if not os.path.exists(csv_file):
            print("No model runs on " + run_date.strftime("%Y%m%d") + " " + member)
            return
        ep = EnsembleMemberProduct(ensemble_name, model_names[0], member, run_date, variable,
                                   start_date, end_date, None, single_step,
                                   size_distribution_training_path,
                                   watershed_obj, map_file=map_file,
                                   condition_model_name=neighbor_condition_model)
        # One directory is used for the check, the creation and the write, so the
        # result no longer depends on whether grib_path ends with a slash.
        grib_out_path = os.path.join(grib_path, run_date.strftime("%Y%m%d/"))
        for model_name in model_names:
            ep.model_name = model_name
            ep.load_forecast_csv_data(forecast_csv_path)
            ep.load_forecast_netcdf_data(netcdf_path)
            try:
                ep.quantile_match()
                grib_objects = ep.encode_grib2_data()
                if not os.path.isdir(grib_out_path):
                    try:
                        os.makedirs(grib_out_path)
                    except OSError:
                        # Another worker may have created the directory in the meantime.
                        if not os.path.isdir(grib_out_path):
                            raise
                ep.write_grib2_files(grib_objects, grib_out_path)
            except Exception:
                # Best effort per model: report the failure and move on. The traceback
                # is printed because the cause is not necessarily a missing netCDF file.
                print("No {0} {1} netcdf files found".format(run_date.strftime("%Y%m%d"), member))
                print(traceback.format_exc())
    except Exception as e:
        # Print here because exceptions raised inside pool workers are otherwise easy to miss.
        print(traceback.format_exc())
        raise e
    return


def member_neighborhood_probability(ensemble_name, model_name, member, run_date, variable, start_date,
                                    end_date, path, single_step, thresholds, radii, queue, map_file,
                                    condition_model_name, condition_threshold, num_samples, sigmas):
    """
    Compute smoothed period-maximum neighborhood probabilities for one ensemble member.

    For every radius in radii and every threshold in thresholds[variable], the
    period-maximum neighborhood probability is computed, smoothed with a Gaussian
    filter for each sigma in sigmas, and put on the queue as the list
    ["period", run_date, model_name, variable, threshold, radius, sigma, field].
    Only period fields are produced; no hourly fields are queued.

    Args:
        ensemble_name: name of the ensemble.
        model_name: name of the model the data comes from.
        member: ensemble member name.
        run_date: model run date.
        variable: variable name; also the key into thresholds.
        start_date: start of the forecast period.
        end_date: end of the forecast period.
        path: data path passed to EnsembleMemberProduct.
        single_step: passed through to EnsembleMemberProduct.
        thresholds: mapping from variable name to an iterable of thresholds.
        radii: iterable of neighborhood radii.
        queue: queue that receives the output lists.
        map_file: passed to EnsembleMemberProduct as map_file.
        condition_model_name: passed through to EnsembleMemberProduct.
        condition_threshold: passed through to EnsembleMemberProduct.
        num_samples: passed to load_data as num_samples.
        sigmas: iterable of Gaussian filter sigmas.

    Returns:
        None
    """
    try:
        print(member, ensemble_name, model_name, variable, run_date)
        member_product = EnsembleMemberProduct(ensemble_name, model_name, member, run_date, variable,
                                               start_date, end_date, path, single_step,
                                               map_file=map_file,
                                               condition_model_name=condition_model_name,
                                               condition_threshold=condition_threshold)
        member_product.load_data(num_samples=num_samples, percentiles=None)
        for radius in radii:
            for threshold in thresholds[variable]:
                print("Period {0} {1} {2} {3}".format(radius, threshold, member, run_date))
                period_prob = member_product.period_max_neighborhood_probability(threshold, radius)
                for sigma in sigmas:
                    message = ["period", run_date, model_name, variable, threshold, radius,
                               sigma, gaussian_filter(period_prob, sigma)]
                    queue.put(message)
    except Exception:
        # Print here because exceptions raised inside pool workers are otherwise easy to miss.
        print(traceback.format_exc())
        raise
    return


def merge_member_probabilities(ensemble_name, model_names, num_members, run_dates, variables, start_hour, end_hour,
                               thresholds, radii, sigmas, out_path, queue):
    """
    Average member probability fields arriving on a queue and write ensemble consensus files.

    An accumulator is created for every combination of run date, (model name, variable)
    pair, threshold, radius and sigma. Each queue item is a list whose first seven
    entries form the accumulator key and whose last entry is the field to add. Once an
    accumulator has received num_members fields, their mean is written to a netCDF file
    in a %Y%m%d subdirectory of out_path and the accumulator is dropped.

    NOTE: the function blocks until every accumulator has received num_members fields,
    so callers must only pass run dates for which all members will report.

    Args:
        ensemble_name: name of the ensemble, used in the output filename.
        model_names: list of model names, paired element-wise with variables.
        num_members: number of fields expected for each accumulator.
        run_dates: iterable of run dates.
        variables: list of variable names, paired element-wise with model_names.
        start_hour: offset in hours from the run date to the start of the period.
        end_hour: offset in hours from the run date to the end of the period.
        thresholds: mapping from variable name to an iterable of thresholds.
        radii: iterable of neighborhood radii.
        sigmas: iterable of Gaussian filter sigmas.
        out_path: base directory for the consensus files.
        queue: queue the member fields are read from.

    Returns:
        0 once every accumulator has been written.
    """
    try:
        merged_data = dict()
        for run_date in run_dates:
            for model_name, variable in zip(model_names, variables):
                for threshold in thresholds[variable]:
                    for radius in radii:
                        for sigma in sigmas:
                            merged_data[("period", run_date, model_name, variable,
                                         threshold, radius, sigma)] = [None, 0]
        while merged_data:
            # A blocking get() waits for the next field without spinning the CPU.
            output = queue.get()
            out_key = tuple(output[0:7])
            field_type, run_date, model_name, variable, threshold, radius, sigma = out_key
            entry = merged_data[out_key]
            if entry[0] is None:
                entry[:] = [output[-1], 1]
            else:
                entry[0] += output[-1]
                entry[1] += 1
            print("Added to ", out_key, "Size: ", entry[1])
            if entry[1] != num_members:
                continue

            start_date = run_date + timedelta(hours=start_hour)
            end_date = run_date + timedelta(hours=end_hour)
            # Number of hourly steps from start_date to end_date inclusive.
            num_hours = max(int((end_date - start_date).total_seconds() // 3600) + 1, 0)

            # One directory is used for the check, the creation and the filename, so
            # the result no longer depends on whether out_path ends with a slash.
            date_str = run_date.strftime("%Y%m%d")
            out_dir = os.path.join(out_path, date_str)
            if not os.path.exists(out_dir):
                os.makedirs(out_dir)
            out_filename = os.path.join(out_dir,
                                        "{0}_{1}_{2}_consensus_{3}.nc".format(ensemble_name,
                                                                              model_name.replace(" ", "-"),
                                                                              variable,
                                                                              date_str))
            ens_mean = entry[0] / float(num_members)
            if field_type == "hourly":
                consensus_type = "neighbor_prob_r_{0:d}_s_{1:d}".format(radius, sigma)
            else:
                consensus_type = "neighbor_prob_{0:02d}-hour_r_{1:d}_s_{2:d}".format(num_hours,
                                                                                     radius,
                                                                                     sigma)
            ec = EnsembleConsensus(ens_mean,
                                   consensus_type,
                                   ensemble_name,
                                   run_date, variable + "_{0:0.2f}".format(threshold),
                                   start_date, end_date, "")
            # Create the file the first time, append on later writes.
            if not os.path.exists(out_filename):
                out_file = ec.init_file(out_filename)
            else:
                out_file = Dataset(out_filename, "a")
            print("Writing to " + out_filename, out_key)
            try:
                ec.write_to_file(out_file)
            finally:
                out_file.close()
            del merged_data[out_key]
    except Exception as e:
        # Print here because exceptions raised inside pool workers are otherwise easy to miss.
        print(traceback.format_exc())
        raise e
    return 0


def ensemble_run_neighborhood_probabilities(ensemble_name, model_name, members, run_date, variable, start_date,
                                            end_date, path, single_step, thresholds, radii, sigmas, out_path,
                                            num_procs, map_file, condition_model_name, condition_threshold,
                                            num_samples):
    """
    Compute and merge neighborhood probabilities for one model, variable and run date.

    One member_neighborhood_probability task per member is submitted to a process pool.
    The calling process then runs merge_member_probabilities on the shared queue until
    every member has reported.

    Args:
        ensemble_name: name of the ensemble.
        model_name: name of the model the data comes from.
        members: list of ensemble member names.
        run_date: model run date.
        variable: variable name; also the key into thresholds.
        start_date: start of the forecast period.
        end_date: end of the forecast period.
        path: data path passed on to the member tasks.
        single_step: passed on to the member tasks.
        thresholds: mapping from variable name to an iterable of thresholds.
        radii: iterable of neighborhood radii.
        sigmas: iterable of Gaussian filter sigmas.
        out_path: base directory for the consensus files.
        num_procs: number of worker processes.
        map_file: passed on to the member tasks.
        condition_model_name: passed on to the member tasks.
        condition_threshold: passed on to the member tasks.
        num_samples: passed on to the member tasks.

    Returns:
        None
    """
    pool = Pool(num_procs)
    manager = Manager()
    q = manager.Queue()
    for member in members:
        pool.apply_async(member_neighborhood_probability, args=(ensemble_name, model_name, member, run_date,
                                                                variable, start_date, end_date, path,
                                                                single_step, thresholds, radii, q, map_file,
                                                                condition_model_name, condition_threshold,
                                                                num_samples, sigmas))
    # merge_member_probabilities takes lists of models, run dates and variables, and
    # hour offsets from the run date, so the single values are wrapped and the period
    # bounds are converted to offsets.
    start_hour = (start_date - run_date).total_seconds() / 3600.0
    end_hour = (end_date - run_date).total_seconds() / 3600.0
    try:
        merge_member_probabilities(ensemble_name, [model_name], len(members), [run_date], [variable],
                                   start_hour, end_hour, thresholds, radii, sigmas, out_path, q)
    finally:
        pool.close()
        pool.join()
    return


def training_data_percentiles(members, ensemble_name, watershed_obj, training_data_path,
                                start_date, end_date, csv_outpath):
    """
    Creates watershed object distribution of sizes for a given set of training netcdf patches

    For every member and every day of the training period, the patch file
    "<ensemble_name>_<YYYYMMDD>-0000_<member>_model_patches.nc" is read if it exists and the
    "<watershed_obj>_curr" values where "masks" equals 1 are collected. 100 percentiles
    between 0.1 and 99.9 of all collected values are then computed and, if csv_outpath
    is set, written to "<ensemble_name>_<watershed_obj>_Size_Distribution.csv".

    Args:
            members (list of str): List of ensemble members in netcdf file
            ensemble_name (str): Name of ensemble member in netcdf file
            watershed_obj (str): Watershed object used in config files
            training_data_path (str): Path prefix of the netcdf files; concatenated
                directly with the filename
            start_date (dict): Start dates; the "train" entry is used
            end_date (dict): End dates; the "train" entry is used
            csv_outpath (str): Path prefix for the output csv; nothing is written if empty
    Returns:
            None
    Raises:
            ValueError: if no training patch file is found for any member and date

    """
    # pd.date_range replaces DatetimeIndex(start=..., end=..., freq=...),
    # a constructor form that pandas 1.0 removed.
    training_dates = pd.date_range(start=start_date["train"],
                                   end=end_date["train"],
                                   freq='1D').strftime("%Y%m%d")

    print('Creating Distribution of Watershed Object Sizes')
    masked_values = []
    for member in members:
        for date in training_dates:
            data_path = training_data_path + '{0}_{1}-0000_{2}_model_patches.nc'.format(ensemble_name,
                                                                                        date, member)
            if not os.path.exists(data_path):
                continue
            dataset = Dataset(data_path)
            try:
                mask_index = np.where(dataset.variables["masks"][:] == 1)
                member_values = dataset.variables["{0}_curr".format(watershed_obj)][:]
                masked_values.append(member_values[mask_index])
            finally:
                # Close each file as soon as its values have been read.
                dataset.close()

    if not masked_values:
        raise ValueError("No training patch files found in {0} for {1}".format(training_data_path,
                                                                               ensemble_name))
    obj_values = np.concatenate(masked_values, axis=0)

    print('Number of {0} Objects Found: {1}'.format(watershed_obj, obj_values.shape[0]))

    percentiles = np.linspace(0.1, 99.9, 100)
    obj_percent_val = np.percentile(obj_values, percentiles)
    dataframe_values = {'Obj_Values': obj_percent_val, 'Percentiles': percentiles}
    obj_value_percentile = pd.DataFrame(data=dataframe_values)
    if csv_outpath:
        csv_path = csv_outpath + '{0}_{1}_Size_Distribution.csv'.format(ensemble_name, watershed_obj)
        obj_value_percentile.to_csv(csv_path, index=False)

        print('Output csv: {0}'.format(csv_path))
    return


# Run the command-line workflow only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()
