#!/usr/bin/env python3

from importlib.resources import files
from log_store           import log_store
from np_reader           import np_reader    as np_rdr
from bdt_scale           import scale_reader as scl_rdr

import utils_noroot as utnr
import rk.utilities as rkut
import pandas       as pnd
import argparse
import extset
import pprint
import os
import re

# Module-level logger, obtained from log_store under this script's name
log = log_store.add_logger(name='rk_extractor:make_yields_table')
#--------------------------------------
class data:
    """Namespace of module-wide settings shared by the functions in this script."""
    # Filled in by get_args() from the command line
    trig = sb_vers = out_dir = None

    # BDT bin used to filter sideband entries, and version passed to the scale reader
    bdt_bin      = 5
    bdt_eff_vers = 'v1'

    # Keyword arguments forwarded to the np_reader constructor
    d_np_arg     = dict(sys='v65', sta='v81', yld='v24')
#--------------------------------------
def get_name(val):
    """Return the table label for a yield variable name.

    Parameters
    ----------
    val : str
        Name of the yield variable, e.g. a key starting with 'ncb_', 'npr_'
        or 'nsg', or exactly 'ncnd_sign_sp' / 'ncnd_ctrl_sp'.

    Returns
    -------
    str
        Label (plain text or LaTeX math) to show in the 'Component' column.

    Raises
    ------
    ValueError
        If ``val`` does not correspond to any known component.
    """
    if val.startswith('ncb_'):
        return 'Combinatorial'

    if val.startswith('npr_'):
        return r'$c\bar{c}_{prc}+\psi(2S)K^+$'

    if val.startswith('nsg'):
        return 'Signal'

    if val == 'ncnd_sign_sp':
        return 'Signal in dswp'

    if val == 'ncnd_ctrl_sp':
        return 'Double swap'

    # A bare `raise` here has no active exception to re-raise and would only
    # produce an unrelated RuntimeError; raise an explicit error instead.
    log.error(f'Invalid variable: {val}')
    raise ValueError(f'Invalid variable: {val}')
#--------------------------------------
def right_bdt_bin(key):
    """Tell whether a sideband-fit key belongs to the configured BDT bin.

    Parameters
    ----------
    key : str
        Key of an entry in the sideband-fit JSON file.

    Returns
    -------
    bool
        True if the key contains no underscore followed by a digit, or if it
        ends with ``_{data.bdt_bin}``; False otherwise.
    """
    # Old JSON files only stored the signal bin by default, so their keys
    # do not end with _{bdt_bin}; such keys are always accepted.
    # Raw string: '\d' in a plain literal is an invalid escape sequence.
    if not re.match(r'.*_\d', key):
        return True

    return key.endswith(f'_{data.bdt_bin}')
#--------------------------------------
def get_sb_data():
    """Load the sideband-fit JSON for the configured version and trigger.

    The file is looked up inside the 'extractor_data' package, and only the
    entries whose key passes right_bdt_bin() are returned.
    """
    pkg_root  = files('extractor_data')
    json_path = pkg_root.joinpath(f'sb_fits/{data.sb_vers}/all_{data.trig}.json')

    d_kept = {}
    for name, entry in utnr.load_json(json_path).items():
        if right_bdt_bin(name):
            d_kept[name] = entry

    return d_kept
#--------------------------------------
def same_trig(key):
    """Tell whether an efficiency key matches the configured trigger.

    True when the key contains 'TIS' and the trigger is GTIS, or when it
    contains 'TOS' and the trigger is MTOS or ETOS.
    """
    if 'TIS' in key and data.trig == 'GTIS':
        return True

    return 'TOS' in key and data.trig in ('MTOS', 'ETOS')
#--------------------------------------
def get_eff(d_eff):
    """Return scaled efficiencies per dataset for the configured trigger.

    Parameters
    ----------
    d_eff : dict
        Maps a key to a pair of efficiencies. Only keys accepted by
        same_trig() are used; the dataset label is the part of the key
        before the first underscore.

    Returns
    -------
    dict
        Dataset label -> scale times the first efficiency of the pair when
        the trigger is MTOS, scale times the second one otherwise.
    """
    l_selected = [(name, pair) for name, pair in d_eff.items() if same_trig(name)]

    d_bdt_wp, _, _ = extset.get_bdt_bin_settings(bdt_bin = data.bdt_bin)

    d_scaled = {}
    for name, (eff_mm, eff_ee) in l_selected:
        dset, *_ = name.split('_')
        # Scale read for this working point, version, dataset and trigger
        scale = scl_rdr(wp=d_bdt_wp, version=data.bdt_eff_vers, dset=dset, trig=data.trig).get_scale()

        if data.trig == 'MTOS':
            d_scaled[dset] = scale * eff_mm
        else:
            d_scaled[dset] = scale * eff_ee

    return d_scaled
#--------------------------------------
def get_sg_data():
    """Return the signal yield entry for the configured trigger.

    The yield is the sum of the 'sign_{trig}' column of the table returned by
    the np_reader; the second element of the entry is fixed to zero.
    """
    reader  = np_rdr(**data.d_np_arg)
    df_yld  = reader.get_ryields()
    total   = df_yld[f'sign_{data.trig}'].sum()

    return {'nsg' : [total, 0]}
#--------------------------------------
def get_df():
    """Build the table of component yields.

    Sideband and signal entries are merged (signal entries win on a key
    clash) and every entry whose key starts with 'n' becomes one row.

    Returns
    -------
    pandas.DataFrame
        Columns 'Component' (label from get_name()) and 'Yield' (value and
        error rounded to integers, joined by a LaTeX plus-minus sign).
    """
    d_data = {**get_sb_data(), **get_sg_data()}

    d_dict = {'Component' : [], 'Yield' : []}
    for key, entry in d_data.items():
        if not key.startswith('n'):
            continue

        # Unpack only after filtering, so skipped entries need not be
        # (value, error) pairs.
        val, err = entry

        d_dict['Component'].append(get_name(key))
        # Raw f-string: '\p' in a plain literal is an invalid escape sequence
        d_dict['Yield'].append(rf'{val:.0f}$\pm${err:.0f}')

    return pnd.DataFrame(d_dict)
#--------------------------------------
def get_args():
    """Parse the command line and store the settings in the data class.

    Sets data.trig and data.sb_vers from the -t/--trig and -s/--sbve options
    (both required) and derives data.out_dir from the sideband version.
    """
    l_trig = ['MTOS', 'ETOS', 'GTIS']

    parser = argparse.ArgumentParser(description='Used to make tables of yields of components for RK toy fits')
    parser.add_argument('-t', '--trig', required=True, type=str, choices=l_trig, help='Trigger')
    parser.add_argument('-s', '--sbve', required=True, type=str, help='Version of sideband fits')
    opts = parser.parse_args()

    data.trig, data.sb_vers = opts.trig, opts.sbve
    data.out_dir            = f'tables/yields/{data.sb_vers}'
#--------------------------------------
def main():
    """Script entry point: read the arguments, build the table, save it as LaTeX."""
    get_args()

    # The output directory may already exist from a previous run
    os.makedirs(data.out_dir, exist_ok=True)

    table    = get_df()
    out_file = f'{data.out_dir}/{data.trig}.tex'

    log.info(f'Saving to: {out_file}')
    table.to_latex(out_file, index=False)
#--------------------------------------
# Build the table only when this file is run as a script, not when imported
if __name__ == '__main__':
    main()

