#!/usr/bin/env python3

import glob
import os
import re
import tqdm
import shutil
import tarfile
import argparse

import pandas            as pnd
import matplotlib.pyplot as plt
import utils_noroot      as utnr

from logzero import logger as log

#-------------------------------------------------------
class data:
    '''
    Container for the settings shared by the functions of this script.
    It is never instantiated, the attributes are read and written on the class itself.
    '''
    # Directory where the job output is looked for and where products are written
    out_path = None
    # Name of the job as passed by the user, None when it was not given
    job_name = None
#-------------------------------------------------------
def rename_jsn():
    '''
    Moves every JSON file found in ./result_jsn (relative to the current
    working directory) into {data.out_path}/output, creating the latter if needed.
    Files already present in the destination with the same name are overwritten.
    '''
    target_dir = f'{data.out_path}/output'
    os.makedirs(target_dir, exist_ok=True)

    for source in glob.glob('result_jsn/*.json'):
        file_name = os.path.basename(source)
        os.replace(source, f'{target_dir}/{file_name}')
#-------------------------------------------------------
def untar(tar_path):
    '''
    Extracts all the contents of a tarball into the current working directory.

    tar_path: Path to the tarball, compression is detected automatically

    Raises tarfile.ReadError if the file cannot be opened as a tar archive.
    '''
    # The context manager closes the archive even if the extraction fails halfway
    with tarfile.open(tar_path) as tar:
        # NOTE(review): the archive contents are trusted here, extractall is called
        # without an extraction filter. Confirm the tarballs only come from our own jobs.
        tar.extractall()
#-------------------------------------------------------
def make_json():
    '''
    Will take tarballs, untar them, and put all the JSON files in output directory

    Does nothing if no job name was set or if {data.out_path}/output already exists.
    Sandboxes are the entries of data.out_path whose name starts with nine digits,
    each is expected to hold a result_jsn.tar.gz file. Tarballs that cannot be read
    are skipped with a warning.

    Raises RuntimeError if no sandbox or no tarball is found.
    '''

    if not data.job_name:
        return

    if os.path.isdir(f'{data.out_path}/output'):
        log.info('JSON directory found, not making it')
        return

    # Only the entry name is matched, this way data.out_path is never interpreted
    # as part of a regular expression (job names can contain characters like + or .)
    l_dirname = [ dirname for dirname in glob.glob(f'{data.out_path}/*') if re.match(r'\d{9}', os.path.basename(dirname))]
    if not l_dirname:
        message = f'Found no sandboxes in {data.out_path}'
        log.error(message)
        raise RuntimeError(message)

    l_tar_path= [ f'{dirname}/result_jsn.tar.gz' for dirname in l_dirname if os.path.isfile( f'{dirname}/result_jsn.tar.gz')]

    if not l_tar_path:
        message = f'Found no tarballs for {data.job_name}'
        log.error(message)
        raise RuntimeError(message)

    log.info('Unpacking JSON files')
    for tar_path in tqdm.tqdm(l_tar_path, ascii=' -'):
        try:
            untar(tar_path)
        except tarfile.ReadError:
            log.warning(f'Cannot untar: {tar_path}')
            continue
        # The tarball is unpacked as ./result_jsn, move its JSON files and
        # remove the directory so that the next tarball starts from scratch
        rename_jsn()
        shutil.rmtree('result_jsn')
#-------------------------------------------------------
def get_data(json_path):
    '''
    Loads a result_xxxx.json file and flattens it into a dictionary that
    can be used to build a one-row dataframe.

    Entries holding a two-element list are taken as [value, error] and stored
    under the keys "<name> value" and "<name> error", as one-element lists of floats.
    Entries holding a string are converted to float and stored under the original key.
    Any other entry is dropped.
    '''
    d_json = utnr.load_json(json_path)

    l_par  = [ (key, val) for key, val in d_json.items() if isinstance(val, list) and len(val) == 2 ]

    d_meta = {}
    for key, val in d_json.items():
        if isinstance(val, str):
            d_meta[key] = float(val)

    d_flat = {}
    for key, (val, err) in l_par:
        d_flat[f'{key} value'] = [float(val)]
        d_flat[f'{key} error'] = [float(err)]

    # Metadata goes last, it wins in case of a name clash
    return {**d_flat, **d_meta}
#-------------------------------------------------------
def get_df():
    '''
    Builds a dataframe with one row per JSON file with fit results.

    The files are searched for in {data.out_path}/output if a job name was set,
    otherwise in {data.out_path}/*/results/result_jsn.

    Raises RuntimeError if no JSON file is found.
    '''
    json_wc = f'{data.out_path}/output/*.json' if data.job_name is not None else f'{data.out_path}/*/results/result_jsn/*.json'
    l_json_path  = glob.glob(json_wc)
    if not l_json_path:
        message = f'No JSON file found in: {json_wc}'
        log.error(message)
        # A bare raise has nothing to re-raise here, raise explicitly with the reason
        raise RuntimeError(message)

    l_df = [ pnd.DataFrame(get_data(json_path)) for json_path in l_json_path ]

    df   = pnd.concat(l_df, axis=0)
    # Every input dataframe starts its index at zero, renumber the rows
    df   = df.reset_index(drop=True)

    return df
#-------------------------------------------------------
def plot(df):
    '''
    Makes all the plots and saves them in {data.out_path}/plots, which is
    created if missing. The pull plot is done first, followed by the
    histograms of the variables listed below.
    '''
    os.makedirs(f'{data.out_path}/plots', exist_ok = True)
    plot_pull(df)

    # Column name and histogram range, None leaves the range unset
    l_var = [
            ('rk value' , (-3, +3)),
            ('rk error' , ( 0,  1)),
            ('converged', ( 0,  2)),
            ('valid'    , ( 0,  2)),
            ('status'   , None    ),
            ]

    for name, rng in l_var:
        plot_var(df, name, rng=rng)
#-------------------------------------------------------
def plot_var(df, name, rng=None):
    '''
    Histograms one column of the dataframe in 50 bins and saves the plot
    as {data.out_path}/plots/<name>.png, with spaces in the name replaced by underscores.

    df  : Dataframe holding the column
    name: Name of the column to plot
    rng : Range of the histogram, passed as is to the plotting call
    '''
    file_stem = name.replace(' ', '_')
    out_file  = f'{data.out_path}/plots/{file_stem}.png'

    df.plot.hist(column=[name], bins=50, range=rng, histtype='step')

    log.info(f'Saving to: {out_file}')
    plt.savefig(out_file)
    # Drop every open figure so that they do not pile up between calls
    plt.close('all')
#-------------------------------------------------------
def plot_pull(df):
    '''
    Plots the distribution of (rk value - 1) / rk error between -3 and 3 and
    saves it as {data.out_path}/plots/pull.png. Vertical lines mark the mean
    (dashed) and the mean plus and minus one standard deviation (dotted).

    The pull is stored in the dataframe passed, as a new column called "pull".
    '''
    df['pull'] = (df['rk value'] - 1 ) / df['rk error']

    sr_pull = df['pull']
    mu      = sr_pull.mean()
    sd      = sr_pull.std()

    df.plot.hist(column=['pull'], bins=50, range=(-3, +3), histtype='step')
    for pos, style in [(mu, '--'), (mu - sd, ':'), (mu + sd, ':')]:
        plt.axvline(x=pos, color='red', linestyle=style)

    out_file = f'{data.out_path}/plots/pull.png'
    log.info(f'Saving to: {out_file}')
    plt.savefig(out_file)
    plt.close('all')
#-------------------------------------------------------
def get_args():
    '''
    Parses the command line arguments and stores the settings in the data class.

    data.job_name is set to the job name, or to None if it was not passed.
    data.out_path is set to output_{job_name} if a job name was passed,
    otherwise to the job path.

    Raises RuntimeError if neither a job name nor a job path is passed.
    '''
    parser = argparse.ArgumentParser(description='Will make plots from the results of toy fits')
    parser.add_argument('-n','--job_name', type=str, help='Name of job, for grid jobs')
    parser.add_argument('-p','--job_path', type=str, help='Path to job output, for IHEP tests')
    args = parser.parse_args()

    # Empty strings are treated as missing, otherwise paths like None/data.json
    # or /data.json would be built downstream
    if not args.job_name and not args.job_path:
        message = 'Neither job name or job path passed'
        log.error(message)
        raise RuntimeError(message)

    # An empty name is stored as None, so that checks of the form "is not None"
    # and "not job_name" agree on whether a job name was given
    data.job_name = args.job_name if args.job_name else None
    data.out_path = f'output_{data.job_name}' if data.job_name else args.job_path
#-------------------------------------------------------
def main():
    '''
    Makes the plots from {data.out_path}/data.json. If that file is missing,
    it is first built from the JSON files with the fit results. The dataframe
    used for plotting is always the one read back from the file.
    '''
    cache_path = f'{data.out_path}/data.json'

    is_cached  = os.path.isfile(cache_path)
    if not is_cached:
        make_json()
        get_df().to_json(cache_path)

    plot(pnd.read_json(cache_path))
#-------------------------------------------------------
# Runs only when executed as a script. The arguments are parsed first
# because main reads data.out_path, which is set by get_args
if __name__ == '__main__':
    get_args()
    main()

