#!/usr/bin/env python3

import os
import glob
import pandas            as pnd
import argparse
import matplotlib.pyplot as plt

from logzero import logger as log

#----------------------------------------
class data:
    """Module-wide holder for the paths used by this script.

    Both attributes start out as ``None`` and are filled in by ``get_args``:
    ``jsn_dir`` is the directory given on the command line and ``out_dir``
    is the ``plots`` subdirectory inside it.
    """
    jsn_dir = out_dir = None
#----------------------------------------
def get_args():
    """Parse the command line and record the working directories in ``data``.

    The required ``-j/--jsn_dir`` option is stored in ``data.jsn_dir``; the
    output directory ``<jsn_dir>/plots`` is stored in ``data.out_dir`` and
    created if it does not exist yet.
    """
    description = 'Used to make plots from JSON files from model checking scripts'
    cli = argparse.ArgumentParser(description=description)
    cli.add_argument('-j', '--jsn_dir', type=str, help='Directory where JSON files go' , required=True)
    parsed = cli.parse_args()

    jsn_dir = parsed.jsn_dir
    out_dir = f'{jsn_dir}/plots'

    data.jsn_dir = jsn_dir
    data.out_dir = out_dir

    # exist_ok=True: re-running on the same directory must not fail
    os.makedirs(out_dir, exist_ok=True)
#----------------------------------------
def load_data(kind):
    """Load all ``<kind>_*.json`` files from ``data.jsn_dir`` into one dataframe.

    Parameters
    ----------
    kind : str
        Prefix of the JSON file names to pick up, e.g. ``'val'``, ``'err'``
        or ``'ini'``.

    Returns
    -------
    pandas.DataFrame
        Concatenation of the dataframes read from every matching file.

    Raises
    ------
    FileNotFoundError
        If no file matches the wildcard.
    """
    jsn_wc = f'{data.jsn_dir}/{kind}_*.json'

    # glob returns matches in arbitrary order; sorting makes the row order
    # reproducible and consistent between calls for different kinds
    l_json = sorted(glob.glob(jsn_wc))
    if not l_json:
        message = f'Found no JSON file in {jsn_wc}'
        log.error(message)
        # A bare `raise` outside an except block only produces an unrelated
        # "No active exception to re-raise" RuntimeError
        raise FileNotFoundError(message)

    l_df = [pnd.read_json(json_path) for json_path in l_json]

    return pnd.concat(l_df)
#----------------------------------------
def main():
    """Entry point: read the command line, load the JSON inputs and plot pulls.

    Loads the ``val``, ``err`` and ``ini`` JSON files from the directory given
    on the command line and passes them to ``plot_pulls``.
    """
    get_args()

    # Each dataframe must come from the files of the matching kind; the
    # pull is (val - ini) / err, so swapping val and err gives wrong plots
    df_val = load_data('val')
    df_err = load_data('err')
    df_ini = load_data('ini')

    plot_pulls(df_ini, df_val, df_err)
#---------------------------------
def plot_pulls(df_ini, df_val, df_err):
    """Save one pull histogram per column of ``df_val``.

    For every column the pull ``(val - ini) / err`` is histogrammed with
    30 bins and written to ``<data.out_dir>/<column>.png``. Nothing is done
    when ``data.out_dir`` is ``None``.

    NOTE(review): assumes ``df_ini`` and ``df_err`` carry the same columns
    and a compatible index as ``df_val`` -- confirm against the JSON inputs.
    """
    out_dir = data.out_dir
    if out_dir is None:
        return

    for column in df_val.columns:
        residual = df_val[column] - df_ini[column]
        pull = residual / df_err[column]
        pull.hist(bins=30)

        png_path = f'{out_dir}/{column}.png'
        log.info(f'Saving to: {png_path}')
        plt.savefig(png_path)

        # Close every figure so the next column starts from a clean canvas
        plt.close('all')
#----------------------------------------
# Run the plotting workflow only when this file is executed as a script,
# not when it is imported as a module
if __name__ == '__main__':
    main()

