#!/usr/bin/env python3

import os
import re
import glob
import numpy
import pprint
import argparse

import pandas            as pnd
import utils_noroot      as utnr
import matplotlib.pyplot as plt

from logzero import logger as log
#-----------------------------------
class data:
    """Module-wide mutable state shared by the functions of this script.

    Used as a plain namespace: attributes are read and written on the class
    itself, never on instances.
    """
    # Number of JSON files expected per job; fixed by the first job checked
    nrow       = None
    # Directory holding the job outputs (set from -d/--res_dir)
    res_dir    = None
    # Directory where plots are written (set from -o/--out_dir)
    out_dir    = None
    # Key looked up in the 'pos' section of each result JSON
    var        = 'rk'
    # Variable name -> list of uncertainties, one entry appended per job
    d_var_err  = {}
    # Variable name -> plot label; filled in by main()
    d_var_name = None
#-------------------------------------------------------
def get_var_naming():
    """Return the mapping from systematic-variable name to its LaTeX plot label.

    Raw strings are used so the LaTeX backslashes are kept literally without
    relying on Python's handling of invalid escape sequences (which warns).

    Returns
    -------
    dict
        Variable name -> matplotlib mathtext label. The key 'None' maps to the
        plain string 'None'.
    """
    d_name = {
        'rsg_mm'   : r'$r_{\sigma}^{\mu\mu}$',
        'rsg_ee'   : r'$r_{\sigma}^{ee}$',
        'rsg'      : r'$r_{\sigma}$',

        'nsg_mm'   : r'$N_{signal}^{\mu\mu}$',
        'nsg_ee'   : r'$N_{signal}^{ee}$',
        'nsg'      : r'$N_{signal}$',

        'ncb_mm'   : r'$N_{comb}^{\mu\mu}$',
        'ncb_ee'   : r'$N_{comb}^{ee}$',
        'ncb'      : r'$N_{comb}$',

        'mu_cb_mm' : r'$\mu_{comb}^{\mu\mu}$',
        'mu_cb_ee' : r'$\mu_{comb}^{ee}$',
        'mu_cb'    : r'$\mu_{comb}$',

        'lm_cb_mm' : r'$\lambda_{comb}^{\mu\mu}$',
        'lm_cb_ee' : r'$\lambda_{comb}^{ee}$',
        'lm_cb'    : r'$\lambda_{comb}$',

        'dmu_mm'   : r'$\Delta(\mu)^{\mu\mu}$',
        'dmu_ee'   : r'$\Delta(\mu)^{ee}$',
        'dmu'      : r'$\Delta(\mu)$',

        'ck'       : '$c_{K}$',
        'None'     : 'None',
    }

    return d_name
#-----------------------------------
def get_args():
    """Parse the command line and store the options on ``data``.

    Sets ``data.res_dir`` and ``data.out_dir`` and creates the output
    directory if it does not exist yet.

    Raises
    ------
    OSError
        If the output directory cannot be created (logged, then re-raised).
    """
    parser = argparse.ArgumentParser(description='This script makes plots and tables to assess the sources of systematics on RK')
    # The input directory is mandatory: without it the job glob would run on 'None/*'
    parser.add_argument('-d','--res_dir', type=str, required=True, help='Directory with the output of systematics jobs')
    parser.add_argument('-o','--out_dir', type=str, help='Directory with plots', default='plots')
    args = parser.parse_args()

    data.res_dir = args.res_dir
    data.out_dir = args.out_dir

    try:
        os.makedirs(data.out_dir, exist_ok=True)
    except OSError:
        log.error(f'Cannot create: {data.out_dir}')
        raise
#-----------------------------------
def check_job(l_jsn):
    """Check that a job produced the same number of JSON files as earlier jobs.

    The first call records ``len(l_jsn)`` in ``data.nrow``. Later calls
    compare against that value.

    Parameters
    ----------
    l_jsn : list of str
        Paths to the result JSON files of one job.

    Raises
    ------
    ValueError
        If the number of files differs from the one seen in the first job.
    """
    if data.nrow is None:
        data.nrow = len(l_jsn)
        return

    if len(l_jsn) != data.nrow:
        # The glob pattern is not passed in, so derive the directory from the files themselves
        json_dir = os.path.dirname(l_jsn[0]) if l_jsn else 'unknown directory'
        msg = f'Found incompatible number of JSON files in {json_dir}: expected {data.nrow}, got {len(l_jsn)}'
        log.error(msg)
        raise ValueError(msg)
#-----------------------------------
def read_data(json_path):
    """Return the uncertainty on ``data.var`` stored in one result JSON file.

    The file is loaded with ``utnr.load_json``. Its ``'pos'`` entry must map
    ``data.var`` to a two-element sequence; the second element is returned.
    The first element, presumably the central value, is ignored.
    """
    pair = utnr.load_json(json_path)['pos'][data.var]
    _, err = pair

    return err
#-----------------------------------
def get_var_name(jsn_path):
    """Extract the systematic-variable name from a result JSON path.

    ``.../result_<name>.json`` gives ``<name>``. Names starting with four
    digits (e.g. a nominal job index) are mapped to the string ``'None'``.

    Parameters
    ----------
    jsn_path : str
        Path to a result JSON file.

    Returns
    -------
    str
        The variable name, or ``'None'`` for numeric names.
    """
    var_name = os.path.basename(jsn_path)

    # Strip only the leading prefix and trailing suffix, not every occurrence
    prefix = 'result_'
    suffix = '.json'
    if var_name.startswith(prefix):
        var_name = var_name[len(prefix):]
    if var_name.endswith(suffix):
        var_name = var_name[:-len(suffix)]

    # Raw string: '\d' in a plain string is an invalid escape sequence
    if re.match(r'\d{4}', var_name):
        var_name = 'None'

    return var_name
#-----------------------------------
def update_data(dir_path):
    """Append the uncertainty of every result JSON in one job directory to ``data.d_var_err``.

    Files are looked up under ``<dir_path>/results/result_jsn/``. If several
    files map to the same variable name, the last one found wins.
    """
    l_jsn = glob.glob(f'{dir_path}/results/result_jsn/*.json')
    check_job(l_jsn)

    var_to_path = {}
    for path in l_jsn:
        var_to_path[get_var_name(path)] = path

    for var_name, path in var_to_path.items():
        # Read first so a failing file does not leave an empty entry behind
        err = read_data(path)
        data.d_var_err.setdefault(var_name, []).append(err)
#-----------------------------------
def _plot_error_hist(var, name, l_error, mu, sg):
    # Save a histogram of one variable's uncertainties, with mean (dashed) and mean +/- std (dotted) lines
    plt.hist(l_error, range=(0, 0.5), bins=30, label=var)
    for x_pos, style in [(mu, '--'), (mu - sg, ':'), (mu + sg, ':')]:
        plt.axvline(x=x_pos, color='red', linestyle=style)

    plt.title(name)
    plt.xlabel(r'$\varepsilon(R_{K})$')
    plt.savefig(f'{data.out_dir}/{var}.png')
    plt.close('all')


def plot_error():
    """Plot the uncertainty distribution of each variable and summarize it in a table.

    For every variable in ``data.d_var_err`` this saves
    ``<data.out_dir>/<var>.png`` and records the mean and standard deviation
    (numpy defaults, ddof=0) of its uncertainties.

    Returns
    -------
    pandas.DataFrame
        Columns 'Variable' (plot label), 'Value' (mean) and 'Error' (std),
        sorted by ascending 'Value', with a fresh 0..n-1 index.
    """
    # Fall back to the raw variable name when no label is defined for it
    d_name = data.d_var_name or {}

    l_row = []
    for var, l_error in data.d_var_err.items():
        name = d_name.get(var, var)
        mu   = numpy.mean(l_error)
        sg   = numpy.std(l_error)

        _plot_error_hist(var, name, l_error, mu, sg)
        l_row.append([name, mu, sg])

    # Build the frame once: linear time and numeric dtypes for 'Value'/'Error'
    df = pnd.DataFrame(l_row, columns=['Variable', 'Value', 'Error'])
    df = df.sort_values(by=['Value'], ascending=True)
    df = df.reset_index(drop=True)

    return df
#-----------------------------------
def get_df():
    """Collect the uncertainties from all job directories and build the summary table.

    Job directories are the entries of ``data.res_dir`` whose basename starts
    with three digits followed by ``_001``.

    Returns
    -------
    pandas.DataFrame
        The table produced by ``plot_error``.

    Raises
    ------
    RuntimeError
        If no job directory is found.
    """
    l_obj = glob.glob(f'{data.res_dir}/*')
    # Sorted so the first job (which fixes data.nrow) is chosen reproducibly
    l_dir = sorted(obj for obj in l_obj if re.match(r'\d{3}_001', os.path.basename(obj)))

    njob = len(l_dir)
    if njob == 0:
        msg = f'Found {njob} jobs in {data.res_dir}'
        log.error(msg)
        raise RuntimeError(msg)

    log.debug(f'Found {njob} jobs')
    for dir_path in l_dir:
        update_data(dir_path)

    return plot_error()
#-----------------------------------
def plot_df(df):
    """Draw the summary table as horizontal error bars and save it as ``systematics.png``.

    Each variable gets a transparent bar at its 'Value' with 'Error' as the
    horizontal error bar, plus a red marker at the value itself.
    """
    ax = df.plot(x='Variable', y='Value', xerr='Error', kind='barh', color='none', capsize=5)
    ax.set_xlim(0, 0.5)
    ax.legend([])
    ax.set_ylabel('')
    ax.set_xlabel(r'$\varepsilon(R_K)$')
    ax.grid()

    # barh places row i at y == i, so the markers line up with the bars
    for row_pos, value in enumerate(df.Value):
        ax.plot(value, row_pos, 'o', color='red')

    plt.savefig(f'{data.out_dir}/systematics.png')
    plt.close('all')
#-----------------------------------
def main():
    """Install the plot labels, build the summary table and plot it."""
    data.d_var_name = get_var_naming()

    summary = get_df()
    plot_df(summary)
#-----------------------------------
if __name__ == '__main__':
    # Options must be parsed into `data` before main() reads them
    get_args()
    main()

