#!/usr/bin/env python3

import os
import glob
import pandas            as pnd
import argparse
import matplotlib.pyplot as plt

from logzero import logger as log

#----------------------------------------
class data:
    '''
    Module-wide settings shared by the functions in this script.

    Both attributes start out as None and are assigned by get_args().
    '''
    jsn_dir = None # directory searched for the input JSON files
    out_dir = None # directory where the PNG plots are written
#----------------------------------------
def get_args(argv=None):
    '''
    Parse the command line and store the resulting settings on `data`.

    Parameters
    ----------
    argv : list of str, optional
        Arguments to parse. When None (the default), argparse reads sys.argv[1:],
        exactly as before; passing a list allows programmatic use.

    Side effects
    ------------
    Sets data.jsn_dir to the --jsn_dir value and data.out_dir to
    <--plt_dir>/plots, then creates data.out_dir (and any missing parents)
    if it does not exist yet.
    '''
    parser = argparse.ArgumentParser(description='Used to make plots from JSON files from model checking scripts')
    parser.add_argument('-j', '--jsn_dir', type=str, help='Directory where JSON files go' , required=True)
    parser.add_argument('-p', '--plt_dir', type=str, help='Directory where plots will go' , required=True)
    args = parser.parse_args(argv)

    data.jsn_dir = args.jsn_dir
    data.out_dir = os.path.join(args.plt_dir, 'plots')

    os.makedirs(data.out_dir, exist_ok=True)
#----------------------------------------
def load_data(kind, jsn_dir=None):
    '''
    Load every `<kind>_*.json` file from the JSON directory into one DataFrame.

    Parameters
    ----------
    kind : str
        File-name prefix to match, e.g. 'val', 'err' or 'ini'.
    jsn_dir : str, optional
        Directory to search. Defaults to data.jsn_dir, which get_args() sets.

    Returns
    -------
    pandas.DataFrame
        Rows of all matching files, concatenated in sorted file-name order
        and re-indexed from 0.

    Raises
    ------
    FileNotFoundError
        If no file in jsn_dir matches the pattern.
    '''
    if jsn_dir is None:
        jsn_dir = data.jsn_dir

    jsn_wc = os.path.join(jsn_dir, f'{kind}_*.json')

    # Escape the directory so characters such as '[' in its name are matched
    # literally rather than interpreted as glob syntax.
    # Sort the matches: glob returns paths in arbitrary order, while plot_pulls
    # combines the val/err/ini frames row by row, so every kind must be stacked
    # in the same file order. NOTE(review): this assumes the files of each kind
    # share the same suffixes (val_X.json <-> err_X.json) — confirm.
    l_json = sorted(glob.glob(os.path.join(glob.escape(jsn_dir), f'{kind}_*.json')))
    if not l_json:
        log.error(f'Found no JSON file in {jsn_wc}')
        # A bare `raise` outside an except block only produced
        # "RuntimeError: No active exception to reraise".
        raise FileNotFoundError(f'Found no JSON file in {jsn_wc}')

    l_df   = [ pnd.read_json(json_path) for json_path in l_json ]
    df     = pnd.concat(l_df)
    df     = df.reset_index(drop=True)

    return df
#----------------------------------------
def main():
    '''
    Script entry point: read the arguments, load the val/err/ini tables and
    write the value and pull histograms.
    '''
    get_args()

    # Tables are loaded in this order: val, err, ini.
    frames = {kind : load_data(kind) for kind in ('val', 'err', 'ini')}

    plot_vals(frames['ini'], frames['val'])
    plot_pulls(frames['ini'], frames['val'], frames['err'])
#---------------------------------
def plot_pulls(df_ini, df_val, df_err):
    '''
    Save one pull histogram per column of df_val.

    For each column, the pull (val - ini) / err is computed element-wise
    (pandas aligns the three columns on their row index). It is histogrammed
    in 30 bins over [-6, +6] and saved to <data.out_dir>/pul_<column>.png.
    '''
    lo, hi = -6, +6
    for col in df_val.columns:
        residual = df_val[col] - df_ini[col]
        pulls    = residual / df_err[col]
        out_png  = f'{data.out_dir}/pul_{col}.png'

        pulls.hist(bins=30, range=[lo, hi])
        log.info(f'Saving to: {out_png}')
        plt.savefig(out_png)
        # Close all figures so the next column is drawn on a fresh figure.
        plt.close('all')
#---------------------------------
def plot_vals(df_ini, df_val):
    '''
    Save one value histogram per column of df_val.

    Each column is histogrammed in 30 bins. A red vertical line marks the
    entry of the same column in df_ini at row label 0. The figure is saved to
    <data.out_dir>/val_<column>.png.
    NOTE(review): only row 0 of df_ini is used, which presumably assumes every
    row holds the same initial value — confirm.
    '''
    for col in df_val.columns:
        out_png = f'{data.out_dir}/val_{col}.png'
        ref_val = df_ini.loc[0, col]

        df_val[col].hist(bins=30)
        plt.axvline(x=ref_val, color='r')
        log.info(f'Saving to: {out_png}')
        plt.savefig(out_png)
        # Close all figures so the next column is drawn on a fresh figure.
        plt.close('all')
#----------------------------------------
if __name__ == '__main__':
    # Run only when executed as a script, not when imported as a module.
    main()

