#!/usr/bin/env python
from netCDF4 import Dataset
from glob import glob
from multiprocessing import Pool
from copy import deepcopy
from scipy.ndimage import gaussian_filter
from hagelslag.util.make_proj_grids import read_arps_map_file, read_ncar_map_file,make_proj_grids
from hagelslag.util.Config import Config
import matplotlib
import matplotlib.pyplot as plt
from mpl_toolkits.basemap import Basemap
import os
import pandas as pd
import numpy as np
import argparse
import pickle

try:
    from ncepgrib2 import Grib2Encode, dump
    grib_support = True
except ImportError:
    # ncepgrib2 is optional; grib_support records whether GRIB2 encoding
    # is available. (The except clause must name the exception class; an
    # exception instance here raises TypeError when the import fails.)
    grib_support = False



def main():
    """
    Command-line entry point: train calibration models and/or produce
    calibrated forecasts (plots, grib2 and netCDF output) for the configured
    run date, hail size thresholds and calibration models.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("config", help="Configuration file") 
    parser.add_argument("-t", "--train", action="store_true", help="Train calibration models.")
    parser.add_argument("-f", "--fore", action="store_true", help="Generate forecasts from calibration models.")
    parser.add_argument("-g", "--grib_out", action="store_true", help="generate grib2 files.")
    parser.add_argument("-n", "--netcdf_out", action="store_true", help="generate netcdf files.")
    parser.add_argument("-p", "--plot_out", action="store_true", help="Plot calibrated forecasts.")
    
    args = parser.parse_args()
    required = ["calibration_model_names", "calibration_model_objs", "ensemble_name", 
                "forecast_model_names","train_data_path", "forecast_data_path", 
                "target_data_path", "target_data_names", "model_path", 
                "run_date_format","size_threshold","sector","out_path","map_file"] 
    config = Config(args.config, required)
    
    # Subsampling step applied to the grid and Gaussian smoothing sigma
    # (in grid points) used by all output writers.
    stride = 1
    smoothing = 14
    
    if not hasattr(config, "run_date_format"):
        config.run_date_format = "%y%m%d"
    
    run_date = config.start_dates['forecast']
    run_date_str = run_date.strftime(config.run_date_format)
    
    if config.map_file[-3:] == "map":                                  
        proj_dict, grid_dict = read_arps_map_file(config.map_file)
    else:                                               
        proj_dict, grid_dict = read_ncar_map_file(config.map_file)   

    if args.grib_out and not grib_support:
        print("ncepgrib2 not available; grib2 output disabled")
        args.grib_out = False
    
    if args.train:
        trained_models = train_calibration(config)
        saving_cali_models(trained_models, config)
    
    if args.fore: 
        print()
        print('Forecast run date(s): {0}'.format(run_date_str))
        
        lsr_models = load_cali_models(config)
        
        print('Creating calibrated forecasts')    
        
        print()
        print('Full 24 hour forecast')
        
        start_time = config.start_hour
        end_time = config.end_hour  
                
        lsr_NEP = create_calibration_forecasts(lsr_models,
                                               start_time, end_time,
                                               stride,
                                               config)
        _write_forecast_outputs(lsr_NEP, args, proj_dict, grid_dict,
                                start_time, end_time, stride, smoothing,
                                run_date, config)
    
        print()
        print('Four-hour forecasts') 

        for hour in range(17, 22):
            start_time = hour
            end_time = hour + 4
                    
            lsr_NEP_hour = create_calibration_forecasts(lsr_models,
                                                        start_time, end_time,
                                                        stride,
                                                        config)
            # Write the four-hour forecasts themselves (not the 24 hour ones).
            _write_forecast_outputs(lsr_NEP_hour, args, proj_dict, grid_dict,
                                    start_time, end_time, stride, smoothing,
                                    run_date, config)
    return


def _write_forecast_outputs(forecasts, args, proj_dict, grid_dict,
                            start_time, end_time, stride, smoothing,
                            run_date, config):
    """
    Plot and/or write every (size, calibration model) forecast in
    ``forecasts`` according to the --plot_out/--grib_out/--netcdf_out flags.

    Combinations missing from ``forecasts`` are reported and skipped.
    """
    for size in config.size_threshold:
        for cali_model_name in config.calibration_model_names:
            nep = forecasts.get(size, {}).get(cali_model_name)
            if nep is None:
                print('No calibrated forecast for size {0}, model {1}; skipping'.format(
                    size, cali_model_name))
                continue

            if args.plot_out:
                forecast_plotting(nep,
                                  proj_dict, grid_dict, start_time,
                                  end_time, stride, size, smoothing,
                                  config)

            if args.grib_out:
                output_grib2(nep,
                             proj_dict, grid_dict, start_time, end_time,
                             stride, size, run_date, config.target_data_names,
                             smoothing, config)

            if args.netcdf_out:
                output_netcdf(nep,
                              proj_dict, grid_dict, start_time, end_time,
                              stride, size, run_date, config.target_data_names,
                              smoothing, config)


def train_calibration(config):
    """
    Train models that calibrate Neighborhood Ensemble Probability (NEP)
    forecasts toward a target dataset.

    For every hail size threshold, loads the NEP training forecasts and the
    matching target grids for each training date where both files exist. It
    then fits a deep copy of every configured calibration model on the
    flattened data. Currently only compatible with Local Storm Reports (LSRs).

    Args:
        config (Config): run configuration. Reads start_dates/end_dates
            ["train"], run_date_format, size_threshold, train_data_path,
            target_data_path, ensemble_name, forecast_model_names,
            start_hour, end_hour, sector, calibration_model_names and
            calibration_model_objs.

    Returns:
        list: a one-element list holding ``{size: {model_name: fitted model}}``.
            Sizes with no usable training data map to an empty dict.
    """
    run_dates = pd.date_range(start=config.start_dates["train"],
                              end=config.end_dates["train"],
                              freq='1D').strftime(config.run_date_format)
    
    lsr_calib_models = {}

    print('Loading Data')

    for size in config.size_threshold:
        lsr_calib_models[size] = {}
        train_files, lsr_files = [], []
        for date in run_dates: 
            train_data_file = config.train_data_path + \
                    "20{3}/netcdf/{0}_Hail_{1}_NEP_{2}_{3}_Hours_{4}-{5}.nc".format(
                        config.ensemble_name, config.forecast_model_names, size, date,
                        config.start_hour, config.end_hour)
                
            if config.sector:
                lsr_data_file = config.target_data_path + 'lsr/{0}_{1}_{2}_mask.nc'.format(
                            date, size, config.sector)
            else:
                lsr_data_file = config.target_data_path + 'lsr/{0}_{1}_mask.nc'.format(date, size)
            
            # Only use dates for which both the forecast and the target exist,
            # so the two flattened arrays stay aligned.
            if os.path.exists(train_data_file) and os.path.exists(lsr_data_file):
                train_files.append(train_data_file)
                lsr_files.append(lsr_data_file)

        if not train_files:
            print("No training data found for size: {0}".format(size))
            continue

        t_data = []
        for path in train_files:
            with Dataset(path) as nc:
                t_data.append(nc.variables["Data"][:])
        train_data = np.array(t_data).flatten()

        l_data = []
        for path in lsr_files:
            with Dataset(path) as nc:
                l_data.append(nc.variables["24_Hour_All_12z_12z"][:])
        lsr_data = np.array(l_data).flatten()
        print("Training size data: {0}".format(size))
        
        for ind, model_name in enumerate(config.calibration_model_names):
            # Copy the template model so each size gets an independent fit.
            lsr_calib_models[size][model_name] = deepcopy(config.calibration_model_objs[ind])
            lsr_calib_models[size][model_name].fit(train_data, lsr_data)
    
    return [lsr_calib_models]



def saving_cali_models(target_dataset_model, config):
    """
    Save calibration machine learning models to pickle files.

    One file is written per (size, model) pair, named
    ``<model_path><model name, spaces -> dashes>_<target_data_names>_<size>mm_calibration.pkl``.

    Args:
        target_dataset_model (dict or list of dict): fitted models keyed by
            hail size threshold, then calibration model name. A list of such
            dicts (as returned by ``train_calibration``) is also accepted;
            every dict in it is saved.
        config (Config): run configuration; reads size_threshold,
            model_path and target_data_names.
    """
    print('Saving Models')
    if not config.size_threshold:
        return

    if isinstance(target_dataset_model, dict):
        model_dicts = [target_dataset_model]
    else:
        model_dicts = list(target_dataset_model)

    for model_dict in model_dicts:
        for size, calibration_model in model_dict.items():
            for model_name, model_objs in calibration_model.items():
                out_cali_filename = config.model_path + \
                                    '{0}_{1}_{2}mm_calibration.pkl'.format(
                                    model_name.replace(" ", "-"),
                                    config.target_data_names, size)

                print('Writing out: {0}'.format(out_cali_filename)) 

                with open(out_cali_filename, "wb") as pickle_file:
                    pickle.dump(model_objs, pickle_file, pickle.HIGHEST_PROTOCOL)
                    
    return



def load_cali_models(config):
    """
    Load calibration models from pickle files.

    Each (size, model) pair is read from the file name that
    ``saving_cali_models`` writes:
    ``<model_path><model name, spaces -> dashes>_<target_data_names>_<size>mm_calibration.pkl``.

    Args:
        config (Config): run configuration; reads size_threshold,
            calibration_model_names, model_path and target_data_names.

    Returns:
        dict: ``{size: {model_name: model}}``. Pairs whose file is missing
            are omitted and reported.
    """
    print()
    print("Load models")
    
    lsr_cali_model = {}

    for size in config.size_threshold:
        lsr_cali_model[size] = {}

        for model_name in config.calibration_model_names:
            model_file = config.model_path + '{0}_{1}_{2}mm_calibration.pkl'.format(
                model_name.replace(" ", "-"), config.target_data_names, size)
            if not os.path.exists(model_file):
                print('Missing calibration model file: {0}'.format(model_file))
                continue
            # NOTE: pickle.load can execute arbitrary code; only load
            # model files produced by this workflow.
            with open(model_file, "rb") as lsr_cmf:
                lsr_cali_model[size][model_name] = pickle.load(lsr_cmf)

    return lsr_cali_model


def create_calibration_forecasts(lsr_models,
                                start_hour,end_hour,
                                stride,config):
    """
    Generate calibrated Neighborhood Ensemble Probability (NEP) predictions.

    For each hail size threshold, reads the NEP forecast for every forecast
    run date whose file exists. It flattens the data and passes it through
    each loaded calibration model's ``transform``. The model is presumably
    scikit-learn-like, e.g. IsotonicRegression (TODO confirm).

    Args:
        lsr_models (dict): ``{size: {model_name: model}}`` from load_cali_models
        start_hour (int): beginning hour of the forecast period (file name)
        end_hour (int): ending hour of the forecast period (file name)
        stride (int): unused here; kept for interface compatibility
        config (Config): run configuration

    Returns:
        dict: ``{size: {model_name: ndarray of shape (n_found_dates, ny, nx)}}``
            of local storm report (lsr) calibrated NEP values. Only dates
            whose forecast file exists are included, in date order. Sizes
            with no files, or models that were not loaded, are left empty.
    """
    run_dates = pd.date_range(start=config.start_dates["forecast"],
                              end=config.end_dates["forecast"],
                              freq='1D').strftime(config.run_date_format)

    lsr_cali_fore = {}
    if not config.size_threshold:
        return lsr_cali_fore

    for size in config.size_threshold:
        lsr_cali_fore[size] = {}

        forecast_files = []
        for date in run_dates: 
            forecast_file = config.forecast_data_path + \
                "20{3}/netcdf/{0}_Hail_{1}_NEP_{2}_{3}_Hours_{4}-{5}.nc".format(
                config.ensemble_name, config.forecast_model_names, size, date,
                start_hour, end_hour)
            if os.path.exists(forecast_file): 
                forecast_files.append(forecast_file)

        if not forecast_files:
            print('No forecast files found for size {0}, hours {1}-{2}'.format(
                size, start_hour, end_hour))
            continue

        fields = []
        for path in forecast_files:
            with Dataset(path) as nc:
                fields.append(nc.variables["Data"][:])
        forecast_data = np.array(fields).flatten()

        # Reshape using the number of files actually read, not the number
        # of requested dates, so missing dates do not break the reshape.
        data_shape = (len(fields), np.shape(fields[0])[0], np.shape(fields[0])[1])

        for model_name in config.calibration_model_names:
            model = lsr_models.get(size, {}).get(model_name)
            if model is None:
                print('No calibration model for size {0}, model {1}'.format(size, model_name))
                continue
            lsr_cali_fore[size][model_name] = \
                model.transform(forecast_data).reshape(data_shape)

    return lsr_cali_fore

def forecast_plotting(forecast,proj_dict,grid_dict,
                    start_hour,end_hour,stride,size,
                    smoothing,config):
    """
    Plot a calibrated NEP forecast on a Lambert conformal map and save it
    as a PNG under ``<out_path>/<YYYYMMDD>/``.

    Only the first run date (``forecast[0]``) is plotted. It is subsampled
    by ``stride`` and Gaussian-smoothed before contouring.

    Args:
        forecast (numpy.ndarray): calibrated NEP forecasts indexed first by run date
        proj_dict (dict): projection information of forecasts
        grid_dict (dict): grid information of forecasts
        start_hour (int): beginning hour of the forecast period
        end_hour (int): ending hour of the forecast period
        stride (int): subsampling step applied to both grid axes
        size (int): hail size threshold (mm)
        smoothing (float): Gaussian filter standard deviation, in grid points
        config (Config): run configuration (out_path, start_dates,
            ensemble_name, target_data_names)
    """
    date_outpath = config.out_path+'{0}/'.format(config.start_dates['forecast'].strftime("%Y%m%d"))
    
    if not os.path.exists(date_outpath):
        os.makedirs(date_outpath)
    
    filtered_forecast = gaussian_filter(forecast[0][::stride,::stride],smoothing)
    
    map_data = make_proj_grids(proj_dict,grid_dict)
    lons = map_data["lon"]
    lats = map_data["lat"]
    
    plt.figure(figsize=(9,6))
    
    m = Basemap(projection='lcc', 
                area_thresh=10000.,
                resolution="l",
                lon_0=proj_dict["lon_0"],
                lat_0=proj_dict["lat_0"],
                lat_1=proj_dict["lat_1"],
                # Second standard parallel, as used by the grib2/netCDF writers.
                lat_2=proj_dict["lat_2"],
                llcrnrlon=lons[0,0],
                llcrnrlat=lats[0,0],
                urcrnrlon=lons[-1,-1],
                urcrnrlat=lats[-1,-1])
    
    m.drawcoastlines()
    m.drawstates()
    m.drawcountries()
    m.fillcontinents(color='gray',alpha=0.2)
    
    x1, y1 = m(lons,lats)
    x1, y1 = x1[::stride,::stride], y1[::stride,::stride]
    
    cmap = matplotlib.colors.ListedColormap(['#DBC6BD','#AD8877','#FCEA8D', 'gold','#F76E67','#F2372E',
                                             '#F984F9','#F740F7','#AE7ADD','#964ADB',
                                             '#99CCFF', '#99CCFF','#3399FF'])
                        
    levels = [0.01,0.05,0.15, 0.225, 0.30, 0.375, 0.45, 0.525, 0.60, 0.70, 0.8, 0.9, 1.0]
                
    plt.contourf(x1,y1,filtered_forecast,cmap=cmap,levels=levels)
    cbar = plt.colorbar(orientation="horizontal", shrink=0.7, fraction=0.05, pad=0.02)
    cbar.set_ticks([0.01,0.05,0.15, 0.30, 0.45, 0.60, 0.80, 1.0])
    cbar.set_ticklabels([1,5,15,30,45,60,80,100])
    
    plt.title("{0} LSR Calibrated >{1}mm NEP \n Valid {2}, Hours: {3}-{4}UTC".format(
            config.ensemble_name,
            size,
            config.start_dates["forecast"].strftime("%d %b %Y"),
            start_hour%24,end_hour%24),
            fontweight="bold",
            fontsize=16)
    
    filename = date_outpath + "{0}_hail_{1}_cali_NEP_{2}mm_{3}_Hours_{4}-{5}.png".format(
                            config.ensemble_name,
                            config.target_data_names,
                            size,
                            config.start_dates["forecast"].strftime("%y%m%d"),
                            start_hour,end_hour)
    
    # Save the figure (the message below announces a written file); an
    # interactive plt.show() would block batch runs.
    plt.savefig(filename, bbox_inches="tight", dpi=300)
    print("Writing to " + filename)
    plt.close()
    
    return
    
def output_grib2(data,proj_dict,grid_dict,start_hour,end_hour,
                stride,size,run_date,target_dataset,smoothing,config):    
    """
    Write a calibrated Neighborhood Probability forecast to a GRIB2 file.

    The first run date (``data[0]``) is subsampled by ``stride`` and
    Gaussian-smoothed, the same field the netCDF and PNG outputs use. It is
    then encoded on the grid described by ``proj_dict``/``grid_dict``
    (grid definition template number 30). Nothing is written when
    ncepgrib2 is unavailable.

    Args: 
        data (numpy.ndarray): calibrated NEP forecasts indexed first by run date
        proj_dict (dict): projection information of forecasts
        grid_dict (dict): grid information of forecasts
        start_hour (int): beginning hour of the forecast period
        end_hour (int): ending hour of the forecast period
        stride (int): subsampling step applied to both grid axes
        size (int): hail size threshold (mm)
        run_date (datetime): run date of the forecast
        target_dataset (str): name of the dataset being calibrated towards
        smoothing (float): Gaussian filter standard deviation, in grid points
        config (Config): run configuration (out_path, ensemble_name)
    """
    if not grib_support:
        print("ncepgrib2 not available; skipping grib2 output")
        return

    date_outpath = config.out_path+'{0}/'.format(run_date.strftime("%Y%m%d"))
    
    if not os.path.exists(date_outpath):
        os.makedirs(date_outpath)

    lscale = 1e6
    grib_id_start = [7, 0, 14, 14, 2]
   
    filtered_forecast = gaussian_filter(data[0][::stride,::stride],smoothing)
    # Grid dimensions must describe the field actually encoded (the strided
    # grid, consistent with dx/dy * stride below).
    ny, nx = filtered_forecast.shape[-2:]
    gdsinfo = np.array([0, ny * nx, 0, 0, 30], dtype=np.int32)
    
    lon_0 = proj_dict["lon_0"]
    sw_lon = grid_dict["sw_lon"]
        
    # GRIB2 longitudes are expressed in [0, 360).
    if lon_0 < 0:
        lon_0 += 360
    if sw_lon < 0:
        sw_lon += 360

    gdtmp1 = [1, 0, proj_dict['a'], 0, float(proj_dict['a']), 0, float(proj_dict['b']),
            nx, ny, grid_dict["sw_lat"] * lscale,
            sw_lon * lscale, 0, proj_dict["lat_0"] * lscale,
            lon_0 * lscale,
            grid_dict["dx"] * 1e3 * stride, grid_dict["dy"] * 1e3 * stride, 0b00000000, 0b01000000,
            proj_dict["lat_1"] * lscale,
            proj_dict["lat_2"] * lscale, -90 * lscale, 0]
    pdtmp1 = np.array([1, 31, 4, 0, 31, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 1], dtype=np.int32)
    drtmp1 = np.array([0, 0, 4, 8, 0], dtype=np.int32)
    time_list = list(run_date.utctimetuple()[0:6])
    grib_objects = Grib2Encode(0, np.array(grib_id_start + time_list + [2, 1], dtype=np.int32))
    grib_objects.addgrid(gdsinfo, gdtmp1)
    pdtmp1[8] = start_hour
    pdtmp1[-2] = 0
    grib_objects.addfield(1, pdtmp1, 0, drtmp1, filtered_forecast)
    grib_objects.end()
    filename = date_outpath + "{0}_hail_{1}_cali_NEP_{2}mm_{3}_Hours_{4}-{5}.grib2".format(config.ensemble_name,
                                                         target_dataset,
                                                         size,
                                                         run_date.strftime("%y%m%d"),
                                                         start_hour,end_hour)
    print("Writing to " + filename)
        
    with open(filename, "wb") as grib_file:
        grib_file.write(grib_objects.msg)

    return

    
def output_netcdf(data,proj_dict,grid_dict,start_hour,end_hour,
                stride,size,run_date,target_dataset,smoothing,config):    
    """
    Write a calibrated Neighborhood Ensemble Probability (NEP) forecast to a
    netCDF4 file under ``<out_path>/<YYYYMMDD>/``.

    The first run date (``data[0]``) is subsampled by ``stride`` and
    Gaussian-smoothed before writing.

    Args: 
        data (numpy.ndarray): calibrated NEP forecasts indexed first by run date
        proj_dict (dict): projection information of forecasts
        grid_dict (dict): grid information of forecasts
        start_hour (int): beginning hour of the forecast period
        end_hour (int): ending hour of the forecast period
        stride (int): subsampling step applied to both grid axes
        size (int): hail size threshold (mm)
        run_date (datetime): run date of the forecast
        target_dataset (str): name of the dataset being calibrated towards
        smoothing (float): Gaussian filter standard deviation, in grid points
        config (Config): run configuration (out_path, ensemble_name)
    """
    field_shape = np.shape(data[0])
    # Build coordinates from integer indices; a float np.arange(0, n*dy, dy)
    # can yield n+1 points through rounding and mismatch the data shape.
    y1 = np.arange(field_shape[0]) * grid_dict['dy']
    x1 = np.arange(field_shape[1]) * grid_dict['dx']
             
    x1, y1 = np.meshgrid(x1, y1)
    x1, y1 = x1[::stride,::stride], y1[::stride,::stride]

    filtered_forecast = gaussian_filter(data[0][::stride,::stride],smoothing)
    
    date_outpath = config.out_path+'{0}/'.format(run_date.strftime("%Y%m%d"))
    
    if not os.path.exists(date_outpath):
        os.makedirs(date_outpath)
    
    filename = "{0}_hail_{1}_cali_NEP_{2}mm_{3}_Hours_{4}-{5}.nc".format(config.ensemble_name,
                                                         target_dataset,
                                                         size,
                                                         run_date.strftime("%y%m%d"),
                                                         start_hour,end_hour)
    out_filename = date_outpath+filename
    
    with Dataset(out_filename, "w") as out_file:
        # Dimension "x" spans array rows and "y" spans columns (kept for
        # compatibility with existing readers of these files).
        out_file.createDimension("x", filtered_forecast.shape[0])
        out_file.createDimension("y", filtered_forecast.shape[1])
        out_file.createVariable("Longitude", "f4", ("x", "y"))
        out_file.createVariable("Latitude", "f4",("x", "y"))
        out_file.createVariable("Data", "f4", ("x", "y"))
        # Despite their names, "Longitude"/"Latitude" hold projected x/y
        # distances from the grid origin, in grid_dict dx/dy units.
        out_file.variables["Longitude"][:,:] = x1
        out_file.variables["Latitude"][:,:] = y1
        out_file.variables["Data"][:,:] = filtered_forecast
        out_file.projection = proj_dict["proj"]
        out_file.lon_0 = proj_dict["lon_0"]
        out_file.lat_0 = proj_dict["lat_0"]
        out_file.lat_1 = proj_dict["lat_1"]
        out_file.lat_2 = proj_dict["lat_2"]
    
    print("Writing to " + out_filename)

    return
if __name__ == "__main__":
    main()
