#!/usr/bin/env python3

import os
import re
import tqdm
import glob
import math
import numpy
import pprint
import tarfile
import argparse

import pandas            as pnd
import utils_noroot      as utnr
import matplotlib.pyplot as plt

from logzero     import logger as log
from scipy.stats import sem    as spy_sem
#-----------------------------------
class data:
    """
    Namespace with the settings and caches shared by the functions of this script.

    The class is used as a plain container: attributes are read and assigned
    directly on the class object.
    """
    # Name of the entry looked up in the 'pos' section of a JSON file by read_data
    var        = 'rk'

    # Input and output directories, assigned in get_args
    res_dir    = None
    out_dir    = None

    # Number of JSON files per job, assigned by check_job on its first call
    nrow       = None

    # Variable name -> list of values / list of errors, filled by update_data
    d_var_val  = {}
    d_var_err  = {}

    # Variable name -> label, assigned in main from get_var_naming
    d_var_name = None

    # Tag returned by get_nfix -> name of the variable dropped, built by get_fix_var
    d_fix_var  = None
#-------------------------------------------------------
def get_var_naming():
    """
    Build the dictionary mapping variable identifiers to the labels used in plots.

    The labels are written as raw strings so that the LaTeX backslashes
    (e.g. ``\\Delta``, ``\\mu``, ``\\sigma``) are not parsed as (invalid)
    Python escape sequences.

    Returns
    -------
    dict
        Variable identifier -> label. The identifier 'None' maps to the
        plain string 'None'.
    """
    d_name = {
        'ee_all_TOS_toys_fit_psi2S_ratio_2017_ETOS_nom_0' : r'$r_{2017}^{PRec}$',
        'ee_all_TOS_toys_fit_psi2S_ratio_2018_ETOS_nom_0' : r'$r_{2018}^{PRec}$',
        'ee_all_TOS_toys_fit_psi2S_ratio_r1_ETOS_nom_0'   : r'$r_{R1}^{PRec}$',
        'ee_all_TOS_toys_fit_psi2S_ratio_r2p1_ETOS_nom_0' : r'$r_{R2.1}^{PRec}$',
        'ncb_ee_all_TOS_toys_fit'                         : r'$N_{cmb}^{ee}$',
        'nom_ee_all_TOS_toys_fit_dmu_ee_2018_ETOS'        : r'$\Delta\mu_{cmb}^{ee}$',
        'nom_ee_all_TOS_toys_fit_r0_ee_2018_ETOS'         : r'$r_{brem}^{0}$',
        'nom_ee_all_TOS_toys_fit_r1_ee_2018_ETOS'         : r'$r_{brem}^{1}$',
        'nom_ee_all_TOS_toys_fit_r2_ee_2018_ETOS'         : r'$r_{brem}^{2}$',
        'nom_ee_all_TOS_toys_fit_ssg_ee_2018_ETOS'        : r'$s_{\sigma}^{ee}$',
        'npr_ee_all_TOS_toys_fit'                         : r'$N_{cc PRec}^{ee}$',
        'ncb_mm_all_TOS_toys_fit'                         : r'$N_{cmb}^{\mu\mu}$',
        'nom_mm_all_TOS_toys_fit_dmu_mm_2018_MTOS'        : r'$\Delta\mu_{cmb}^{\mu\mu}$',
        'nom_mm_all_TOS_toys_fit_ssg_mm_2018_MTOS'        : r'$s_{\sigma}^{\mu\mu}$',
        'ck'                                              : r'$c_{K}$',
        'None'                                            : 'None',
    }

    return d_name
#-----------------------------------
def get_args():
    """
    Parse the command line and store the directories in the ``data`` namespace.

    Sets ``data.res_dir`` and ``data.out_dir`` and creates the output
    directory if it does not exist yet.

    Raises
    ------
    OSError
        If the output directory cannot be created; the failure is logged
        before the exception is propagated.
    """
    parser = argparse.ArgumentParser(description='This script makes plots and tables to assess the sources of systematics on RK')
    parser.add_argument('-d','--res_dir', type=str, help='Directory with the output of systematics jobs')
    parser.add_argument('-o','--out_dir', type=str, help='Directory with plots', default='plots')
    args = parser.parse_args()

    data.res_dir = args.res_dir
    data.out_dir = args.out_dir

    try:
        os.makedirs(data.out_dir, exist_ok=True)
    except OSError:
        # Only filesystem failures are reported here; a bare except would also
        # intercept KeyboardInterrupt/SystemExit and mislabel them.
        log.error(f'Cannot create: {data.out_dir}')
        raise
#-----------------------------------
def check_job(l_jsn):
    """
    Check that every job provides the same number of JSON files.

    The first call records ``len(l_jsn)`` in ``data.nrow``; later calls
    compare against that recorded number.

    Parameters
    ----------
    l_jsn : list
        JSON file paths belonging to one job.

    Raises
    ------
    ValueError
        If the number of files differs from the one recorded on the first call.
    """
    njsn = len(l_jsn)

    if data.nrow is None:
        data.nrow = njsn
        return

    if njsn != data.nrow:
        # The original message referenced a name that does not exist in this
        # scope and then used a bare raise with no active exception.
        msg = f'Found incompatible number of JSON files: expected {data.nrow}, got {njsn}'
        log.error(msg)
        raise ValueError(msg)
#-----------------------------------
def read_data(json_path):
    """
    Load a JSON file and return the entry ``data.var`` of its 'pos' section.

    Parameters
    ----------
    json_path : str
        Path handed to ``utnr.load_json``.

    Returns
    -------
    tuple
        The two elements stored under ``data.var``, returned as (val, err).
    """
    d_pos    = utnr.load_json(json_path)['pos']
    val, err = d_pos[data.var]

    return val, err
#-----------------------------------
def get_var_name(jsn_path):
    """
    Extract the variable name from the path of a result JSON file.

    The base name is stripped of the substrings 'result_' and '.json'. If what
    remains starts with four digits, the string 'None' is returned instead.

    Parameters
    ----------
    jsn_path : str
        Path to a JSON file.

    Returns
    -------
    str
        Variable name, or 'None' when the stripped name starts with four digits.
    """
    file_name = os.path.basename(jsn_path)
    var_name  = file_name.replace('result_', '')
    var_name  = var_name.replace('.json'  , '')

    # Raw string: '\d' in a normal string literal is an invalid escape sequence
    if re.match(r'\d{4}', var_name):
        var_name = 'None'

    return var_name
#-----------------------------------
def check_size(cnt, msg):
    """
    Make sure a container is not empty.

    Parameters
    ----------
    cnt : sized container
        Object supporting ``len``.
    msg : str
        Message logged and attached to the exception when ``cnt`` is empty.

    Raises
    ------
    ValueError
        If ``cnt`` has zero length.
    """
    if len(cnt) == 0:
        log.error(msg)
        # A bare raise here has no active exception to re-raise; raise an
        # explicit exception that carries the message instead.
        raise ValueError(msg)
#-----------------------------------
def add_val(var, qty, d_dat):
    """
    Append ``qty`` to the list stored under ``var`` in ``d_dat``.

    A new single-element list is created when ``var`` is not yet a key.
    The dictionary is modified in place; nothing is returned.
    """
    d_dat.setdefault(var, []).append(qty)
#-----------------------------------
def update_data(dir_path):
    """
    Read the JSON files of one job directory and accumulate values and errors.

    For each file the variable name is obtained with get_var_name and the
    (value, error) pair with read_data. Values are stored in
    ``data.d_var_val`` for names containing ':' and for 'None'; errors are
    stored in ``data.d_var_err`` for names without ':' (which includes 'None').

    Parameters
    ----------
    dir_path : str
        Directory searched with the wildcard ``<dir_path>/*.json``.
    """
    jsn_wc = f'{dir_path}/*.json'
    # NOTE(review): get_json_from_wc is not defined in the visible part of this
    # file; presumably it expands the wildcard into a list of paths -- confirm.
    l_jsn  = get_json_from_wc(jsn_wc)
    check_size(l_jsn, f'Empty list of JSON files in: {jsn_wc}')

    check_job(l_jsn)

    # Later files with the same variable name replace earlier ones
    d_var = {}
    for jsn_path in l_jsn:
        d_var[get_var_name(jsn_path)] = jsn_path
    check_size(d_var, 'Empty variable -> JSON path dictionary')

    for var, json_path in d_var.items():
        val, err   = read_data(json_path)
        is_none    = var == 'None'
        has_colon  = ':' in var

        if has_colon or is_none:
            add_val(var, val, data.d_var_val)

        if not has_colon or is_none:
            add_val(var, err, data.d_var_err)

    check_size(data.d_var_err, 'Empty variable -> error dictionary')
    check_size(data.d_var_val, 'Empty variable -> value dictionary')
#-----------------------------------
def plot_error():
    """
    Plot, for every variable in ``data.d_var_err``, the distribution of errors.

    One histogram per variable is saved as ``<out_dir>/rk_err_<var>.png`` with
    vertical lines at the mean and at the mean plus/minus the standard error
    of the mean.

    Returns
    -------
    pandas.DataFrame
        Columns 'Variable', 'Value' (mean) and 'Error' (standard error of the
        mean), sorted by increasing 'Value'.
    """
    df = pnd.DataFrame(columns=['Variable', 'Value', 'Error'])

    for var, l_error in data.d_var_err.items():
        name = data.d_var_name[var]
        mu   = numpy.mean(l_error)
        sg   = spy_sem(l_error)

        plt.hist(l_error, range=(0, 0.2), bins=30, label=var)
        plt.axvline(x=mu   , color='red', linestyle='--')
        plt.axvline(x=mu-sg, color='red', linestyle=':')
        plt.axvline(x=mu+sg, color='red', linestyle=':')

        plot_path = f'{data.out_dir}/rk_err_{var}.png'
        # Raw f-string: '\m' in a normal string literal is an invalid escape sequence
        plt.legend(['Error', rf'$\mu={mu:.3f}$'])
        plt.title(name)
        plt.xlabel(r'$\varepsilon(R_{K})$')
        plt.savefig(plot_path)
        plt.close('all')

        df = utnr.add_row_to_df(df, [name, mu, sg])

    df = df.sort_values(by=['Value'], ascending=True)
    df = df.reset_index(drop=True)

    return df
#-----------------------------------
def plot_value():
    """
    Plot, for every variable in ``data.d_var_val``, the distribution of values.

    One histogram per variable is saved as ``<out_dir>/rk_val_<var>.png`` with
    vertical lines at the mean and at the mean plus/minus the standard error
    of the mean.

    Returns
    -------
    pandas.DataFrame
        Columns 'Variable', 'Value' (mean) and 'Error' (standard error of the
        mean), sorted by increasing 'Value'.
    """
    df = pnd.DataFrame(columns=['Variable', 'Value', 'Error'])

    for var, l_value in data.d_var_val.items():
        name = data.d_var_name[var]
        mu   = numpy.mean(l_value)
        sg   = spy_sem(l_value)

        plt.hist(l_value, range=(0, 2), bins=30, label=var)
        plt.axvline(x=mu   , color='red', linestyle='--')
        plt.axvline(x=mu-sg, color='red', linestyle=':')
        plt.axvline(x=mu+sg, color='red', linestyle=':')

        plot_path = f'{data.out_dir}/rk_val_{var}.png'
        # Raw f-string: '\m' in a normal string literal is an invalid escape sequence
        plt.legend(['$R_K$', rf'$\mu={mu:.3f}$'])
        plt.title(name)
        plt.xlabel(r'$R_{K}$')
        plt.savefig(plot_path)
        plt.close('all')

        df = utnr.add_row_to_df(df, [name, mu, sg])

    df = df.sort_values(by=['Value'], ascending=True)
    df = df.reset_index(drop=True)

    return df
#-----------------------------------
def untar(tar_path):
    """
    Extract a tarball next to itself unless it was already extracted.

    Parameters
    ----------
    tar_path : str
        Path to the tar file.

    Returns
    -------
    bool
        False if ``tar_path`` is not an existing file. True otherwise, i.e.
        when the directory ``result_jsn`` already exists beside the tarball or
        after the archive has been extracted into the tarball's directory.
    """
    if not os.path.isfile(tar_path):
        return False

    dir_path    = os.path.dirname(tar_path)
    is_unpacked = os.path.isdir(f'{dir_path}/result_jsn')

    # Extraction is skipped when a previous run already produced the directory
    if not is_unpacked:
        with tarfile.open(tar_path) as itar:
            itar.extractall(path=dir_path)

    return True
#-----------------------------------
def get_nfix(file_path):
    """
    Extract the four-digit tag from a file named ``result_<digits>_<dddd>_nominal.json``.

    Parameters
    ----------
    file_path : str
        Path to the JSON file; only the base name is inspected.

    Returns
    -------
    str
        The four digits captured from the file name (as a string).

    Raises
    ------
    ValueError
        If the base name does not match the expected pattern.
    """
    file_name = os.path.basename(file_path)
    rgx       = r'result_\d+_(\d{4})_nominal\.json'

    mtch = re.match(rgx, file_name)
    if not mtch:
        # A bare raise here has no active exception to re-raise; raise an
        # explicit exception that carries the message instead.
        msg = f'Cannot extract number of fixed parameters from: "{file_name}"'
        log.error(msg)
        raise ValueError(msg)

    return mtch.group(1)
#-----------------------------------
def merge_sys(d_json, d_nom, dir_path):
    """
    For each systematic, keep the variation whose rk is furthest from nominal.

    Parameters
    ----------
    d_json : dict
        Systematic name -> list of JSON paths with the variations.
    d_nom : dict
        Nominal parameters; ``d_nom['pos']['rk'][0]`` is the nominal rk value.
    dir_path : str
        Directory where ``<sys>.json`` is written with the selected variation.
    """
    rk_nom = d_nom['pos']['rk'][0]

    def _distance(rk_val):
        # Absolute shift of a variation with respect to the nominal rk
        return abs(rk_val - rk_nom)

    for sys, l_json_path in d_json.items():
        l_d_par  = []
        for path in l_json_path:
            l_d_par.append(utnr.load_json(path))

        l_rk_val = []
        for d_par in l_d_par:
            l_rk_val.append(d_par['pos']['rk'][0])

        # First occurrence of the value with the largest shift
        i_var     = l_rk_val.index(max(l_rk_val, key=_distance))
        json_path = f'{dir_path}/{sys}.json'

        utnr.dump_json(l_d_par[i_var], json_path)
        log.debug(f'Largest variation: {l_json_path[i_var]} -> {json_path}')
#-----------------------------------
def merge_bts(d_json, d_nom, dir_path):
    """
    Combine, per systematic, the parameters of several JSON files via merge_bts_par.

    Parameters
    ----------
    d_json : dict
        Systematic name -> list of JSON paths.
    d_nom : dict
        Nominal parameters with 'pre' and 'pos' sections; the parameter names
        are taken from ``d_nom['pos']['par']``.
    dir_path : str
        Directory used to build the output path ``<sys>:bts.json``.

    Notes
    -----
    The merged dictionary is built but not written: the dump call is disabled
    in the code, so only the debug message is emitted.
    """
    l_par = d_nom['pos']['par']

    for sys, l_json_path in d_json.items():
        d_mrg   = {'pre' : d_nom['pre'], 'pos' : {}}
        l_d_par = [ utnr.load_json(path) for path in l_json_path ]

        # Each merged parameter is stored with a zero error
        for par in l_par:
            d_mrg['pos'][par] = [merge_bts_par(par, l_d_par, d_nom), 0]

        json_path = f'{dir_path}/{sys}:bts.json'
        #utnr.dump_json(d_mrg, json_path)
        log.debug(f'Mergin bootstrapped parameters into: {json_path}')
#-----------------------------------
def merge_bts_par(par_name, l_d_par, d_nom):
    """
    Shift the nominal value of a parameter by the RMS deviation of its variations.

    Parameters
    ----------
    par_name : str
        Key looked up in the 'pos' section of every dictionary.
    l_d_par : list of dict
        Variations; the value is ``d_par['pos'][par_name][0]``.
    d_nom : dict
        Nominal parameters; the value is ``d_nom['pos'][par_name][0]``.

    Returns
    -------
    float
        Nominal value plus the square root of the mean squared deviation.
    """
    l_val   = [ d_par['pos'][par_name][0] for d_par in l_d_par ]
    par_nom = d_nom['pos'][par_name][0]

    mean_sq = sum((val - par_nom) ** 2 for val in l_val) / len(l_d_par)

    return par_nom + math.sqrt(mean_sq)
#-----------------------------------
def get_toy_sys_df(json_path, itoy):
    """
    Build the dataframe of one JSON result file.

    Every key of the 'pre' section gives one row. The fitted value and error
    come from the 'pos' section when the key is present there; otherwise the
    'pre' value is used as fitted value and the error is set to zero.

    Parameters
    ----------
    json_path : str
        Path to the JSON file; its name must be understood by get_nfix.
    itoy : int
        Index stored in the 'toy' column.

    Returns
    -------
    pandas.DataFrame
        Columns 'var', 'gen', 'fit', 'err', 'sys' and 'toy'.
    """
    d_par = utnr.load_json(json_path)
    d_pre = d_par['pre']
    d_pos = d_par['pos']

    l_var = list(d_pre.keys())
    l_gen = []
    l_fit = []
    l_err = []

    for var in l_var:
        gen = d_pre[var][0]

        if var in d_pos:
            fit = d_pos[var][0]
            err = d_pos[var][1]
        else:
            # Not in the 'pos' section: fall back to the 'pre' value, no error
            fit = gen
            err = 0

        l_gen.append(gen)
        l_fit.append(fit)
        l_err.append(err)

    df        = pnd.DataFrame({'var' : l_var, 'gen' : l_gen, 'fit' : l_fit, 'err' : l_err})
    df['sys'] = data.d_fix_var[get_nfix(json_path)]
    df['toy'] = itoy

    return df
#-----------------------------------
def get_fix_var(l_json):
    """
    Fill ``data.d_fix_var`` with the variable dropped at each step, once.

    The files are processed in sorted order. For each file the set of 'pos'
    keys that also appear in 'pre' is compared with the set of the previous
    file: the single key that disappeared is stored under the tag returned by
    get_nfix. The first file is stored with the string 'None'.

    Nothing is done if ``data.d_fix_var`` was already built.

    Parameters
    ----------
    l_json : list of str
        JSON paths; the list is sorted in place.

    Raises
    ------
    ValueError
        From the unpacking, when the number of variables dropped between two
        consecutive files is not exactly one.
    """
    if data.d_fix_var is not None:
        return

    # In-place sort: the caller's list is reordered too
    l_json.sort()

    s_prev = None
    for json_path in l_json:
        nfix   = get_nfix(json_path)
        d_jdat = utnr.load_json(json_path)
        s_curr = { key for key in list(d_jdat['pos'].keys()) if key in d_jdat['pre'] }

        if s_prev is None:
            data.d_fix_var = { nfix : 'None' }
        else:
            [var_drop]           = list(s_prev.difference(s_curr))
            data.d_fix_var[nfix] = var_drop

        s_prev = s_curr
#-----------------------------------
def get_toy_df(dir_path, i_toy):
    """
    Build the dataframes of all the ``*_nominal.json`` files of a directory.

    Parameters
    ----------
    dir_path : str
        Directory searched for JSON files.
    i_toy : int
        Index forwarded to get_toy_sys_df.

    Returns
    -------
    list of pandas.DataFrame
        One dataframe per selected JSON file.
    """
    l_json = []
    for file_path in glob.glob(f'{dir_path}/*.json'):
        if file_path.endswith('_nominal.json'):
            l_json.append(file_path)

    # Builds data.d_fix_var on the first call (and sorts l_json in place then)
    get_fix_var(l_json)

    l_df = []
    for json_path in l_json:
        l_df.append(get_toy_sys_df(json_path, i_toy))

    return l_df
#-----------------------------------
def get_df():
    """
    Return the dataframe with the results of all toys, using a cache if present.

    If ``<out_dir>/data.json`` exists it is loaded and returned. Otherwise the
    entries of ``data.res_dir`` whose name starts with eight digits are
    searched for ``result_jsn.tar.gz``, the tarballs are extracted with untar
    and the dataframes produced by get_toy_df are concatenated, cached to
    ``<out_dir>/data.json`` and returned.

    Returns
    -------
    pandas.DataFrame
        Concatenation of the per-file dataframes, with a fresh index.

    Raises
    ------
    ValueError
        If no dataframe could be built from ``data.res_dir``.
    """
    json_path = f'{data.out_dir}/data.json'
    if os.path.isfile(json_path):
        log.info(f'Cached data found loading: {json_path}')
        return pnd.read_json(json_path)

    l_obj = glob.glob(f'{data.res_dir}/*')
    # Raw string: '\d' in a normal string literal is an invalid escape sequence
    l_job = [ obj for obj in l_obj if re.match(r'\d{8}', os.path.basename(obj)) ]

    log.info('Checking inputs and untarring')
    # Keep only the jobs whose tarball exists; untar extracts it if needed
    l_dir = [ f'{job_dir}/result_jsn' for job_dir in tqdm.tqdm(l_job, ascii=' -') if untar(f'{job_dir}/result_jsn.tar.gz') ]

    l_df = []
    for i_toy, dir_path in enumerate(tqdm.tqdm(l_dir, ascii=' -')):
        l_df += get_toy_df(dir_path, i_toy)

    if not l_df:
        # pandas.concat would fail with a ValueError that does not say where
        # the inputs were expected; report the directory instead.
        msg = f'No toy results found in: {data.res_dir}'
        log.error(msg)
        raise ValueError(msg)

    df = pnd.concat(l_df)
    df = df.reset_index(drop=True)

    log.info(f'Caching data to: {json_path}')
    df.to_json(json_path, indent=4)

    return df
#-----------------------------------
def plot_rk_sys(df):
    """
    Plot the error on rk for each entry of the 'sys' column.

    Two figures are written to ``data.out_dir``:

    - ``rk_dst.png``: overlaid histograms of the 'err' column of the rows with
      ``var == 'rk'``, one per value of 'sys' (only if such rows exist).
    - ``rk_sys.png``: mean of 'err' versus the label of each 'sys' value.

    Parameters
    ----------
    df : pandas.DataFrame
        Dataframe with at least the columns 'var', 'sys' and 'err'.
    """
    l_nam = []
    l_unc = []

    # 'var' must be accessed with brackets: DataFrame.var is a method
    df_rk = df[df['var'] == 'rk']
    if not df_rk.empty:
        _, ax = plt.subplots(figsize=(10,6))
        for sys, df_sys in df_rk.groupby('sys'):
            mu   = df_sys.err.mean()
            name = data.d_var_name[sys]
            # Raw f-string: '\m' in a normal string literal is an invalid escape sequence
            df_sys.err.hist(bins=40, range=[0, 0.3], histtype='step', label=rf'{name}; $\mu={mu:.3f}$', ax=ax)

            l_nam.append(name)
            l_unc.append(mu)

        plt.legend()
        plt.xlabel(r'$\varepsilon(R_K)$')
        plt.savefig(f'{data.out_dir}/rk_dst.png')
        plt.close('all')

    plt.subplots(figsize=(10,6))
    plt.plot(l_nam, l_unc)
    plt.grid()
    plt.ylabel(r'$\varepsilon(R_K)$')
    # Rotate the tick labels before tight_layout so that the layout accounts for them
    plt.xticks(rotation=60)
    plt.tight_layout()
    plt.savefig(f'{data.out_dir}/rk_sys.png')
    plt.close('all')
#-----------------------------------
def main():
    """
    Run the analysis: load the labels, build the dataframe and make the plots.

    Expects ``data.res_dir`` and ``data.out_dir`` to be set already (see get_args).
    """
    data.d_var_name = get_var_naming()

    plot_rk_sys(get_df())
#-----------------------------------
# Script entry point: the command line is parsed first because main() reads
# the directories that get_args stores in the data namespace.
if __name__ == '__main__':
    get_args()
    main()

