#!/usr/bin/env python
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import argparse
import os
from os.path import exists
from hagelslag.util.make_proj_grids import read_arps_map_file, read_ncar_map_file, make_proj_grids 
from hagelslag.data.HailForecastGrid import HailForecastGrid
from mpl_toolkits.basemap import Basemap
from matplotlib.colors import LinearSegmentedColormap
from netCDF4 import Dataset, date2num

# ncepgrib2 is an optional dependency that is only needed to write GRIB2 files.
# grib_support records whether it could be imported.
try:
    from ncepgrib2 import Grib2Encode, dump
    grib_support = True
except ImportError:
    # An except clause must name the exception class. Naming an instance
    # (ImportError("...")) raises TypeError as soon as the import really fails.
    grib_support = False


def main():
    """
    Command-line entry point: load the hail forecasts and write the requested products.

    The run date, start/end dates, ensemble name, ML model name, member list, variable,
    GRIB message number and input path are read from the command line and used to load a
    HailForecastGrid. Depending on the flags given, GRIB2 files (-i), NetCDF files (-c),
    PNG maps (-t) and per-member subplot PNGs (-f) are then written below the --out path.
    Products are generated for time indices 0-24 and, except for the subplots, also for the
    windows (5, 9), (6, 10), ..., (9, 13). If no output flag is given, nothing is loaded
    or written.

    Exits through ``parser.error`` (usage message, exit status 2) when an argument that the
    requested output cannot work without is missing, instead of failing later with an
    AttributeError/TypeError/NameError.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-r", "--run", help="Date of the model run.")
    parser.add_argument("-s", "--start", help="Start Date of the model run time steps.")
    parser.add_argument("-e", "--end", help="End Date of the model run time steps.")
    parser.add_argument("-n", "--ens", help="Name of the ensemble.")
    parser.add_argument("-a", "--map_file", help="Map file")
    parser.add_argument("-m", "--model", help="Name of the machine learning model")
    parser.add_argument("-b", "--members", help="Comma-separated list of members.")
    parser.add_argument("-v", "--var", default="hail", help="Variable being plotted.")
    parser.add_argument("-g", "--grib", type=int, default=1, help="GRIB message number.")
    parser.add_argument("-p", "--path", help="Path to GRIB files")
    parser.add_argument("-o", "--out", help="Path where figures are saved.")
    parser.add_argument("-i", "--grib_out", action="store_true", help="Generate grib2 files")
    parser.add_argument("-t", "--image_out", action="store_true", help="Generate image files.")
    parser.add_argument("-c", "--netcdf_out", action="store_true", help="Generate netcdf files.")
    parser.add_argument("-f", "--subplot_out", action="store_true", help="Generate subplot png files.")

    args = parser.parse_args()
    if args.members is None:
        # The member list is split unconditionally below.
        parser.error("-b/--members is required.")

    run_date = pd.Timestamp(args.run).to_pydatetime()
    map_file = args.map_file
    start_date = pd.Timestamp(args.start).to_pydatetime()
    end_date = pd.Timestamp(args.end).to_pydatetime()
    ensemble_members = args.members.split(",")
    neighbor_radius = 42.0
    neighbor_smoothing = 14  # number of grid points
    thresholds = np.array([25, 50])
    stride = 1
    # (start, end) time-index windows for which products are generated.
    full_period = (0, 24)
    sub_periods = [(first, first + 4) for first in range(5, 10)]

    if not any([args.grib_out, args.image_out, args.netcdf_out, args.subplot_out]):
        return

    # Validate before the (potentially slow) data load so that mistakes are reported at once.
    if args.out is None:
        # Every output function builds its target directory by string concatenation with --out.
        parser.error("-o/--out is required when an output flag (-i, -t, -c, -f) is given.")
    if args.grib_out:
        if not grib_support:
            parser.error("-i/--grib_out needs the ncepgrib2 package, which could not be imported.")
        if map_file is None:
            parser.error("-a/--map_file is required for -i/--grib_out.")

    print("Loading data")
    forecast_grid = load_hail_forecasts(run_date, start_date, end_date, args.ens, args.model,
                                        ensemble_members, args.var, args.grib, args.path)
    if args.grib_out:
        print("Output grib2 files")
        for first, last in [full_period] + sub_periods:
            output_grib2_files(forecast_grid, first, last, run_date, map_file, args.out,
                               neighbor_radius, neighbor_smoothing, thresholds, stride)
    if args.netcdf_out:
        print("Output netcdf files")
        for first, last in [full_period] + sub_periods:
            output_netcdf_files(forecast_grid, first, last, args.out, neighbor_radius,
                                neighbor_smoothing, thresholds, stride)

    if args.image_out or args.subplot_out:
        print("Get basemaps")
        bmap_sub = get_sub_basemap(forecast_grid, stride)
        bmap_full = get_full_basemap(forecast_grid)
        valid_date = forecast_grid.start_date.strftime("%d %b %Y")

        if args.image_out:
            print("Daily plots: neighborhood probability, ensemble_max")
            plot_period_ensemble_max(full_period[0], full_period[1], valid_date,
                                     forecast_grid, bmap_full, args.out)
            plot_period_neighborhood_probability(forecast_grid, bmap_sub, neighbor_radius,
                                                 neighbor_smoothing, thresholds,
                                                 stride, args.out, full_period[0], full_period[1])

            print("Hourly plots: neighborhood probability, ensemble_max")
            for first, last in sub_periods:
                plot_period_ensemble_max(first, last, valid_date,
                                         forecast_grid, bmap_full, args.out)
                plot_period_neighborhood_probability(forecast_grid, bmap_sub, neighbor_radius,
                                                     neighbor_smoothing, thresholds, stride,
                                                     args.out, first, last)
        if args.subplot_out:
            plot_subplots(full_period[0], full_period[1], valid_date,
                          forecast_grid, bmap_full, args.out, "ens_max", ensemble_members)
    return


def load_hail_forecasts(run_date, start_date, end_date, ensemble_name, ml_model, members, variable,
                        message_number, path):
    """
    Create a HailForecastGrid from the given settings, load its data and return it.

    All arguments are handed to the HailForecastGrid constructor positionally and in the
    order in which they are declared here.

    Returns:
            The HailForecastGrid after its load_data() method has been called.
    """
    grid_settings = (run_date, start_date, end_date, ensemble_name, ml_model, members,
                     variable, message_number, path)
    hail_grid = HailForecastGrid(*grid_settings)
    hail_grid.load_data()
    return hail_grid

def output_grib2_files(forecast_grid,start_time,end_time,run_date,map_,out_path,
                       radius,smoothing,thresholds,stride):
    """
    Calculates the neighborhood probability and ensemble maximum hail
    over a specified time period, and outputs grib files.

    One "Ensemble-Maximum" file and one "NEP_<threshold>mm" file per threshold are written
    to <out_path><YYYYMMDD>/, which is created if necessary. out_path is concatenated
    directly with the date, so it must already end with a path separator.

    Args:
            forecast_grid: data from input forecast grib file
            start_time: first time index of the period (inclusive)
            end_time: last time index of the period (exclusive)
            run_date: Date of input forecast grib file
            map_: Map file associated with the input/output grib files
            out_path: path to where output grib files are stored
            radius: neighborhood radius passed to period_neighborhood_probability
            smoothing: width of the Gaussian smoother passed to period_neighborhood_probability
                       (main() supplies it as a number of grid points)
            thresholds: iterable of exceedance intensities; one NEP file is written per value
            stride: number of grid points to skip for reduced neighborhood grid
    Returns:
            None. Grib2 files of neighborhood probability and ensemble maximum
            over the specified time period are written to disk.
    """
    date_outpath = out_path + '{0}/'.format(forecast_grid.start_date.strftime("%Y%m%d"))
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(date_outpath, exist_ok=True)

    # Hour labels used in the file names are the time indices shifted by 12.
    start_label = start_time + 12
    end_label = end_time + 12

    # Maximum over the members (axis 0) and over the selected time steps; the time window is
    # cut out first so that the member maximum is only computed for the steps that are needed.
    ens_max = forecast_grid.data[:, start_time:end_time, :, :].max(axis=0).max(axis=0)
    # The ensemble maximum lives on the full-resolution grid, hence a stride of 1.
    grib2_ouput(forecast_grid, start_time, ens_max, map_, run_date, date_outpath, "Ensemble-Maximum",
                start_label, end_label, 1)

    for threshold in thresholds:
        neighbor_prob = forecast_grid.period_neighborhood_probability(radius, smoothing, threshold,
                                                                      stride, start_time, end_time)
        neighbor_prob = neighbor_prob.mean(axis=0)
        # The probabilities are computed on the [::stride] grid, so the grid spacing written to
        # the file has to be scaled by that same stride rather than by a hard-coded 14.
        grib2_ouput(forecast_grid, start_time, neighbor_prob, map_, run_date, date_outpath,
                    "NEP_{0}mm".format(threshold), start_label, end_label, stride)

    return

def output_netcdf_files(forecast_grid,start_time,end_time,out_path,radius,
                        smoothing,thresholds,stride):
    """
    Calculates the neighborhood probability over a specified time period, and outputs netcdf files.
    Ensemble Maximum hail could be added in the future.

    One "NEP_<threshold>" file per threshold is written to <out_path><YYYYMMDD>/netcdf/, which
    is created if necessary. out_path is concatenated directly with the date, so it must
    already end with a path separator.

    Args:
            forecast_grid: data from input forecast file
            start_time: first time index of the period (inclusive)
            end_time: last time index of the period (exclusive)
            out_path: path to where output files are stored
            radius: neighborhood radius passed to period_neighborhood_probability
            smoothing: width of the Gaussian smoother passed to period_neighborhood_probability
                       (main() supplies it as a number of grid points)
            thresholds: iterable of exceedance intensities; one file is written per value
            stride: number of grid points to skip for reduced neighborhood grid

    Returns:
            None. Netcdf files of neighborhood probability over the specified time period
            are written to disk.
    """
    date_outpath = out_path + '{0}/netcdf/'.format(forecast_grid.start_date.strftime("%Y%m%d"))
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(date_outpath, exist_ok=True)

    # Coordinates of the reduced grid do not depend on the threshold, so slice them once.
    n_lon = forecast_grid.lon[::stride, ::stride]
    n_lat = forecast_grid.lat[::stride, ::stride]
    # Hour labels used in the file names are the time indices shifted by 12.
    start_label = start_time + 12
    end_label = end_time + 12

    for threshold in thresholds:
        neighbor_prob = forecast_grid.period_neighborhood_probability(radius, smoothing, threshold,
                                                                      stride, start_time, end_time)
        neighbor_prob = neighbor_prob.mean(axis=0)
        netcdf_output(forecast_grid, date_outpath, neighbor_prob,
                      n_lon, n_lat, "NEP_{0}".format(threshold),
                      start_label, end_label)

    return


def get_sub_basemap(forecast_grid, stride=14):
    """
    Build a Basemap for the grid that results from taking every `stride`-th point.

    The projection settings come from forecast_grid.projparams. The lower-left corner is
    grid point [0, 0]; the upper-right corner is the last point of the lon/lat arrays after
    slicing them with [::stride, ::stride].

    Args:
            forecast_grid: object providing projparams, lon and lat
            stride: step used to subsample the lon/lat arrays
    Returns:
            The Basemap instance.
    """
    proj = forecast_grid.projparams
    strided_lon = forecast_grid.lon[::stride, ::stride]
    strided_lat = forecast_grid.lat[::stride, ::stride]
    basemap_settings = {
        "projection": proj["proj"],
        "resolution": "l",
        "rsphere": proj["a"],
        "lon_0": proj["lon_0"],
        "lat_0": proj["lat_0"],
        "lat_1": proj["lat_1"],
        "lat_2": proj["lat_2"],
        "llcrnrlon": forecast_grid.lon[0, 0],
        "llcrnrlat": forecast_grid.lat[0, 0],
        "urcrnrlon": strided_lon[-1, -1],
        "urcrnrlat": strided_lat[-1, -1],
    }
    return Basemap(**basemap_settings)


def get_full_basemap(forecast_grid):
    """
    Build a Basemap that spans the complete forecast grid.

    The projection settings come from forecast_grid.projparams; the map corners are the
    first ([0, 0]) and last ([-1, -1]) points of the lon/lat arrays.

    Args:
            forecast_grid: object providing projparams, lon and lat
    Returns:
            The Basemap instance.
    """
    proj = forecast_grid.projparams
    lon = forecast_grid.lon
    lat = forecast_grid.lat
    return Basemap(projection=proj["proj"], resolution="l",
                   rsphere=proj["a"],
                   lon_0=proj["lon_0"], lat_0=proj["lat_0"],
                   lat_1=proj["lat_1"], lat_2=proj["lat_2"],
                   llcrnrlon=lon[0, 0], llcrnrlat=lat[0, 0],
                   urcrnrlon=lon[-1, -1], urcrnrlat=lat[-1, -1])


def plot_hourly_ensemble_max():
    """Placeholder without an implementation; takes no arguments and returns None."""
    return None


def plot_hourly_neighborhood_probability():
    """Placeholder without an implementation; takes no arguments and returns None."""
    return None

def plot_subplots(start_time,end_time,valid_date, forecast_grid, bmap, out_path, plot_mode, ensemble_members,
                  contours=np.concatenate([[1] + np.arange(5, 80, 5)]),cmap="inferno"):
    """
    Plot the period maximum of every ensemble member as a panel of subplots.

    The members are split into two groups (first half, second half) and one PNG is written
    per non-empty group to <out_path><YYYYMMDD>/png/. Each panel shows the maximum of that
    member over the time indices start_time:end_time. The index of the first member of the
    group is the last component of the file name and keeps the two files apart.

    Args:
            start_time: first time index of the period (inclusive)
            end_time: last time index of the period (exclusive)
            valid_date: date string shown in the figure title
            forecast_grid: object providing data (member, time, y, x), lon, lat, start_date,
                           ensemble_name, ml_model and variable
            bmap: Basemap used to project lon/lat and to draw the boundaries
            out_path: output root; must end with a path separator because it is
                      concatenated directly with the date
            plot_mode: label used in the title and in the file name
            ensemble_members: member names, indexed like the first axis of forecast_grid.data
            contours: contour levels passed to contourf
            cmap: accepted for interface compatibility but not used; the plot always uses the
                  lower half of the "Paired" color map
    Returns:
            None. PNG files are written to disk.

    NOTE(review): the default `contours` adds 1 to every element of arange(5, 80, 5), i.e. it
    evaluates to the levels 6, 11, ..., 76. If [1, 5, 10, ..., 75] was intended, a comma is
    missing after `[1]`. The default is left as it is because changing it changes every plot.
    """
    paired = plt.get_cmap("Paired")
    colors = paired(np.linspace(0, 0.5, paired.N // 2))
    cmap2 = LinearSegmentedColormap.from_list("Lower Half Paired", colors)

    n_members = forecast_grid.data.shape[0]
    members_in_half = round(n_members / 2.0)
    member_groups = [list(range(0, members_in_half)), list(range(members_in_half, n_members))]

    # Neither the projected coordinates nor the output directory depend on the group.
    x, y = bmap(forecast_grid.lon, forecast_grid.lat)
    date_outpath = out_path + '{0}/png/'.format(forecast_grid.start_date.strftime("%Y%m%d"))
    os.makedirs(date_outpath, exist_ok=True)

    for images in member_groups:
        if not images:
            # A single-member ensemble leaves the first group empty; there is nothing to draw
            # and no contour set to attach a color bar to.
            continue
        # Two columns; four rows as before, more only if a group has more than 8 members.
        n_rows = max(4, (len(images) + 1) // 2)
        f = plt.figure(figsize=(25, 25))
        plt.subplots_adjust(left=0.2, bottom=0.1, right=0.9, top=0.96, wspace=0.01, hspace=0.05)
        for n, mem in enumerate(images, 1):
            ax = f.add_subplot(n_rows, 2, n)
            bmap.drawstates()
            bmap.drawcountries()
            bmap.drawcoastlines()
            data = ax.contourf(x, y, forecast_grid.data[mem, start_time:end_time, :, :].max(axis=0),
                               contours, cmap=cmap2, extend="max")
            ax.set_title(ensemble_members[mem])
        plt.colorbar(data, orientation='horizontal', shrink=0.7, fraction=0.05, pad=0.02)

        # Fields 4-6 are the valid date and the start/end hours. The former "{3} {4}-{5}"
        # printed the variable name a second time and dropped the end hour.
        f.suptitle("{0} {1} {2} {3} (mm), Valid {4} {5}-{6} UTC".format(forecast_grid.ensemble_name,
                                                                        forecast_grid.ml_model.replace("-", " "),
                                                                        plot_mode,
                                                                        forecast_grid.variable.capitalize(),
                                                                        valid_date,
                                                                        ((start_time+12)%24),((end_time+12)%24)),
                   fontweight="bold",
                   fontsize=10)

        # The undefined name size_dist_mode used to be passed here and raised a NameError.
        # Without it, the seven arguments match the seven fields of the template.
        plt.savefig(date_outpath + "{0}_{1}_{2}_{3}_time_{4}_{5}_{6}.png".format(forecast_grid.ensemble_name,
                                                                            forecast_grid.ml_model,
                                                                            plot_mode,
                                                                            forecast_grid.start_date.strftime("%y%m%d"),
                                                                            (start_time+12),(end_time+12),
                                                                            images[0]),
                    bbox_inches="tight", dpi=300)
        plt.close()
    return



def plot_period_ensemble_max(start_time,end_time,valid_date, forecast_grid, bmap, out_path,
                              figsize=(10, 6), contours=np.concatenate([[1] + np.arange(5, 80, 5)]),
                             cmap="inferno"):
    """
    Plot the ensemble maximum over a period as a filled-contour map and save it as a PNG.

    The maximum is taken over the members (axis 0 of forecast_grid.data) and over the time
    indices start_time:end_time. The image is written to <out_path><YYYYMMDD>/png/, which is
    created if it does not exist.

    Args:
            start_time: first time index of the period (inclusive)
            end_time: last time index of the period (exclusive)
            valid_date: date string shown in the title
            forecast_grid: object providing data, lon, lat, start_date, ensemble_name,
                           ml_model and variable
            bmap: Basemap used to project lon/lat and to draw the boundaries
            out_path: output root, concatenated directly with the date
            figsize: figure size passed to plt.figure
            contours: contour levels passed to contourf
            cmap: not used; it is replaced by the lower half of the "Paired" color map
    Returns:
            None. A PNG file is written to disk.

    NOTE(review): the default `contours` evaluates to 6, 11, ..., 76 (1 is added to every
    element of the arange) - confirm whether [1, 5, 10, ..., 75] was intended.
    """
    paired = plt.get_cmap("Paired")
    lower_half_colors = paired(np.linspace(0, 0.5, paired.N // 2))
    lower_half_cmap = LinearSegmentedColormap.from_list("Lower Half Paired", lower_half_colors)

    # Member maximum first, then the maximum over the requested time window.
    period_max = forecast_grid.data.max(axis=0)[start_time:end_time, :, :].max(axis=0)

    plt.figure(figsize=figsize)
    bmap.drawstates()
    bmap.drawcountries()
    bmap.drawcoastlines()
    map_x, map_y = bmap(forecast_grid.lon, forecast_grid.lat)
    plt.contourf(map_x, map_y, period_max, contours, cmap=lower_half_cmap, extend="max")

    day_dir = out_path + '{0}/png/'.format(forecast_grid.start_date.strftime("%Y%m%d"))
    if not exists(day_dir):
        os.makedirs(day_dir)

    # Hour labels are the time indices shifted by 12; the title wraps them at 24.
    first_hour = start_time + 12
    last_hour = end_time + 12
    title_text = "{0} {1} Ensemble Maximum {2} (mm), Valid {3} {4}-{5} UTC".format(
        forecast_grid.ensemble_name,
        forecast_grid.ml_model.replace("-", " "),
        forecast_grid.variable.capitalize(),
        valid_date,
        first_hour % 24,
        last_hour % 24)
    plt.title(title_text, fontweight="bold", fontsize=12)
    plt.colorbar(orientation="horizontal", shrink=0.7, fraction=0.05, pad=0.02)

    png_name = "{0}_{1}_ens_max_{2}_time_{3}_{4}.png".format(
        forecast_grid.ensemble_name,
        forecast_grid.ml_model,
        forecast_grid.start_date.strftime("%y%m%d"),
        first_hour,
        last_hour)
    plt.savefig(day_dir + png_name, bbox_inches="tight", dpi=300)
    plt.close()
    return


def plot_period_neighborhood_probability(forecast_grid, bmap, radius, smoothing, thresholds, stride, out_path,
                                         start_time,end_time,figsize=(10, 6),
                                         contours=np.concatenate((np.array([0.01, 0.05]), np.arange(0.1, 1.1, 0.1))),
                                         cmap="RdPu"):
    """
    Plot the ensemble neighborhood probability of exceeding each threshold and save the maps.

    For every threshold the member-mean of period_neighborhood_probability is drawn as a
    filled-contour map on the [::stride] grid and written as a PNG to
    <out_path><YYYYMMDD>/png/, which is created if necessary.

    Args:
            forecast_grid: object providing period_neighborhood_probability, lon, lat,
                           start_date, ensemble_name, ml_model and variable
            bmap: Basemap used to project lon/lat and to draw the boundaries
            radius: neighborhood radius passed to period_neighborhood_probability
            smoothing: smoother width passed to period_neighborhood_probability; title and
                       file name show int(smoothing*3) as km (presumably grid points on a
                       3 km grid - TODO confirm)
            thresholds: iterable of exceedance intensities; one figure is written per value
            stride: number of grid points to skip for the reduced neighborhood grid
            out_path: output root, concatenated directly with the date
            start_time: first time index of the period
            end_time: last time index of the period
            figsize: figure size passed to plt.figure
            contours: not used; the contour levels are fixed inside the function
            cmap: not used; the color map is fixed inside the function
    Returns:
            None. PNG files are written to disk.
    """
    date_outpath = out_path + '{0}/png/'.format(forecast_grid.start_date.strftime("%Y%m%d"))
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(date_outpath, exist_ok=True)

    # Everything below is independent of the threshold, so it is set up once.
    n_lon = forecast_grid.lon[::stride, ::stride]
    n_lat = forecast_grid.lat[::stride, ::stride]
    n_x, n_y = bmap(n_lon, n_lat)

    prob_cmap = matplotlib.colors.ListedColormap(['#DBC6BD', '#AD8877', '#FCEA8D',
                                                  'gold', '#F76E67', '#F2372E',
                                                  '#F984F9', '#F740F7', '#AE7ADD', '#964ADB',
                                                  '#99CCFF', '#99CCFF', '#3399FF'])
    levels = [0.01, 0.05, 0.15, 0.225, 0.30, 0.375, 0.45, 0.525, 0.60, 0.70, 0.8, 0.9, 1.0]
    # Color bar ticks are labelled in percent.
    tick_levels = [0.01, 0.05, 0.15, 0.30, 0.45, 0.60, 0.80, 1.0]
    tick_labels = [1, 5, 15, 30, 45, 60, 80, 100]

    for threshold in thresholds:
        plt.figure(figsize=figsize)
        neighbor_prob = forecast_grid.period_neighborhood_probability(radius, smoothing, threshold,
                                                                      stride, start_time, end_time)
        neighbor_prob = neighbor_prob.mean(axis=0)

        bmap.drawstates()
        bmap.drawcoastlines()
        bmap.drawcountries()

        plt.contourf(n_x, n_y, neighbor_prob, extend="max", cmap=prob_cmap, levels=levels)
        cbar = plt.colorbar(orientation="horizontal", shrink=0.7, fraction=0.05, pad=0.02)
        cbar.set_ticks(tick_levels)
        cbar.set_ticklabels(tick_labels)

        # "\\sigma" produces the same text as before; a bare "\s" is an invalid escape sequence.
        plt.title("{0} {1} Ens. Neighbor Prob. of {2}>{3:d} mm\nR={4:d} km, $\\sigma$={5:d} km, Valid {6}".format(
            forecast_grid.ensemble_name,
            forecast_grid.ml_model.replace("-", " "),
            forecast_grid.variable.capitalize(),
            int(threshold),
            int(radius),
            int(smoothing*3),
            forecast_grid.start_date.strftime("%d %b %Y")),
            fontweight="bold",
            fontsize=10)
        filename = date_outpath + "{0}_{1}_NEP_{2:d}_r_{3:d}_s_{4:d}_{5}_time_{6}_{7}.png".format(
            forecast_grid.ensemble_name,
            forecast_grid.ml_model,
            int(threshold),
            int(radius),
            int(smoothing*3),
            forecast_grid.start_date.strftime("%y%m%d"),
            (start_time+12),(end_time+12))
        plt.savefig(filename,
                    bbox_inches="tight", dpi=300)
        plt.close()
    return
                                     
def grib2_ouput(forecast_grid,start_time,data,map_,run_date,path,plot_mode,start,end,stride):
    """
    Writes out a grib2 file for a given Ensemble Maximum or Neighborhood Probability numpy array.

    Args:
            forecast_grid: object providing ensemble_name, ml_model and start_date for the file name
            start_time: first time index of the period; start_time + 12 is stored in the
                        product definition template
            data: 2D array written as the single field of the message
            map_: path of the map file; a name ending in "map" is read with read_arps_map_file,
                  anything else with read_ncar_map_file
            run_date: datetime of the model run, written to the identification section
            path: output directory; must end with a path separator because it is
                  concatenated directly with the file name
            plot_mode: label used in the file name (the callers in this file pass
                       "Ensemble-Maximum" or "NEP_<threshold>mm")
            start: start hour used in the file name
            end: end hour used in the file name
            stride: factor by which the dx/dy of the map file are multiplied, i.e. the
                    subsampling of `data` relative to the grid described by the map file
    Returns:
            None. The encoded message is written to disk.
    Raises:
            ImportError: if the optional ncepgrib2 package could not be imported.
    """
    if not grib_support:
        # Without this check the missing encoder only shows up as a NameError further down.
        raise ImportError("ncepgrib2 is required to write GRIB2 files but could not be imported.")

    if map_.endswith("map"):
        proj_dict, grid_dict = read_arps_map_file(map_)
    else:
        proj_dict, grid_dict = read_ncar_map_file(map_)

    lscale = 1e6
    grib_id_start = [7, 0, 14, 14, 2]
    # np.prod replaces np.product, which is no longer available in NumPy 2.
    # NOTE(review): the trailing 30 is presumably the grid definition template number - confirm.
    gdsinfo = np.array([0, np.prod(data.shape[-2:]), 0, 0, 30], dtype=np.int32)
    # Longitudes are written in the 0-360 range.
    lon_0 = proj_dict["lon_0"]
    sw_lon = grid_dict["sw_lon"]
    if lon_0 < 0:
        lon_0 += 360
    if sw_lon < 0:
        sw_lon += 360

    gdtmp1 = [1, 0, proj_dict['a'], 0, float(proj_dict['a']), 0, float(proj_dict['b']),
              data.shape[-1], data.shape[-2], grid_dict["sw_lat"] * lscale,
              sw_lon * lscale, 0, proj_dict["lat_0"] * lscale,
              lon_0 * lscale,
              grid_dict["dx"] * 1e3 * stride, grid_dict["dy"] * 1e3 * stride, 0b00000000, 0b01000000,
              proj_dict["lat_1"] * lscale,
              proj_dict["lat_2"] * lscale, -90 * lscale, 0]
    pdtmp1 = np.array([1, 31, 4, 0, 31, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 1], dtype=np.int32)
    drtmp1 = np.array([0, 0, 4, 8, 0], dtype=np.int32)
    time_list = list(run_date.utctimetuple()[0:6])
    grib_objects = Grib2Encode(0, np.array(grib_id_start + time_list + [2, 1], dtype=np.int32))
    grib_objects.addgrid(gdsinfo, gdtmp1)
    pdtmp1[8] = start_time+12
    pdtmp1[-2] = 0
    grib_objects.addfield(1, pdtmp1, 0, drtmp1, data)
    grib_objects.end()
    filename = path + "{0}_Hail_{1}_{2}_{3}_Hours_{4}-{5}.grib2".format(forecast_grid.ensemble_name,
                                                         forecast_grid.ml_model,
                                                         plot_mode,
                                                         forecast_grid.start_date.strftime("%y%m%d"),
                                                         start,end)
    print("Writing to " + filename)
    # The context manager closes the file even if the write fails.
    with open(filename, "wb") as grib_file:
        grib_file.write(grib_objects.msg)

    return

def netcdf_output(forecast_grid, path, data, data_lon, data_lat, output_mode, start, end):
    """
    Writes a 2D field together with its longitudes and latitudes to a netCDF file.

    The file holds the float32 variables "Longitude", "Latitude" and "Data" on the
    dimensions ("x", "y"), where "x" has the length data.shape[0] and "y" the length
    data.shape[1]. The projection name, lon_0, lat_0, lat_1 and lat_2 of
    forecast_grid.projparams are stored as global attributes.

    Args:
            forecast_grid: object providing ensemble_name, ml_model, start_date and projparams
            path: output directory; must end with a path separator because it is
                  concatenated directly with the file name
            data: 2D array stored as "Data"
            data_lon: 2D longitudes with the shape of data
            data_lat: 2D latitudes with the shape of data
            output_mode: label used in the file name
            start: start hour used in the file name
            end: end hour used in the file name
    Returns:
            None. The file is written to disk.
    """
    out_filename = path + "{0}_Hail_{1}_{2}_{3}_Hours_{4}-{5}.nc".format(forecast_grid.ensemble_name,
                                                                        forecast_grid.ml_model,
                                                                        output_mode,
                                                                        forecast_grid.start_date.strftime("%y%m%d"),
                                                                        start,end)
    projparams = forecast_grid.projparams
    # The context manager closes the dataset even if one of the writes below fails.
    with Dataset(out_filename, "w") as out_file:
        out_file.createDimension("x", data.shape[0])
        out_file.createDimension("y", data.shape[1])
        out_file.createVariable("Longitude", "f4", ("x", "y"))
        out_file.createVariable("Latitude", "f4", ("x", "y"))
        out_file.createVariable("Data", "f4", ("x", "y"))
        out_file.variables["Longitude"][:, :] = data_lon
        out_file.variables["Latitude"][:, :] = data_lat
        out_file.variables["Data"][:, :] = data
        out_file.projection = projparams["proj"]
        out_file.lon_0 = projparams["lon_0"]
        out_file.lat_0 = projparams["lat_0"]
        out_file.lat_1 = projparams["lat_1"]
        out_file.lat_2 = projparams["lat_2"]

    print("Writing to " + out_filename)

    return

# Run the command-line workflow only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
