#!/usr/bin/env python3

import os
import pprint
import argparse
import utils_noroot as utnr
import zutils.utils as zut

from importlib.resources import files
from rk_model   import rk_model
from normalizer import normalizer
from logzero    import logger    as log
from np_reader  import np_reader as np_rdr
from rk         import utilities as rkut

#-----------------------------
class data:
    '''
    Mutable, module-wide configuration shared by the functions of this script.

    The list defaults below are replaced by the command-line values in
    get_args(); version and out_dir start as None and are filled in by
    get_args() and prepare_output() respectively.
    '''
    # Datasets to fit (default for -d/--dataset)
    l_dset = ['r1', 'r2p1', '2017', '2018', 'all']
    # Triggers to fit (default for -t/--trigger)
    l_trig = ['MTOS', 'ETOS', 'GTIS']

    # Output version string, taken from -v/--version
    version = None
    # Directory where JSON results are written, set by prepare_output()
    out_dir = None
#-----------------------------
def prepare_output():
    '''
    Create the output directory for `data.version` inside the `extractor_data`
    package and store its path in `data.out_dir`.

    Raises
    ------
    FileExistsError: If the directory for this version already exists, so that
                     results of a previous run are not mixed with new ones.
    '''
    out_dir = files('extractor_data').joinpath(data.version)
    if os.path.isdir(out_dir):
        log.error(f'Version {data.version} already found')
        # A bare `raise` outside an `except` block only yields
        # "RuntimeError: No active exception to reraise"; raise a meaningful error.
        raise FileExistsError(f'Output directory already exists: {out_dir}')

    os.makedirs(out_dir, exist_ok=True)

    data.out_dir = out_dir
#-----------------------------
def get_args():
    '''
    Parse the command line and copy the options into the `data` config class.

    Sets `data.version`, `data.l_dset` and `data.l_trig`. When -t or -d is not
    given, the list already stored in `data` is used as the default.
    '''
    parser = argparse.ArgumentParser(description='Used to calculate normalizations for combinatorial and PRec from data sidebands')
    parser.add_argument('-v', '--version', type=str, required=True, help='Version of output, used to name directories')
    parser.add_argument('-t', '--trigger', nargs='+', default=data.l_trig, help='Triggers')
    parser.add_argument('-d', '--dataset', nargs='+', default=data.l_dset, help='Datasets')

    opts = parser.parse_args()

    data.version, data.l_dset, data.l_trig = opts.version, opts.dataset, opts.trigger
#-----------------------------
def get_model(dset):
    '''
    Build the fit model and its constraints for one dataset/trigger key.

    Parameters
    ----------
    dset: Key forwarded to `rk_model` as `l_dset=[dset]`; `fit_dset` passes
          strings of the form `{dataset}_{TOS|TIS}`.

    Returns
    -------
    Tuple `(d_mod, d_val, d_var)`: the result of `rk_model.get_model()`
    followed by the two values returned by `rk_model.get_cons()`.
    '''
    reader       = np_rdr(sys='v65', sta='v63', yld='v24')
    reader.cache = True

    d_eff  = reader.get_eff()
    d_byld = reader.get_byields()

    # Average the B yields (skipping 'TIS') and convert them with the
    # jpsi-based reso_to_rare mapping; the result feeds rk_model as d_nent
    d_avg  = rkut.average_byields(d_byld, l_exclude=['TIS'])
    d_rare = rkut.reso_to_rare(d_avg, kind='jpsi')

    model        = rk_model(preffix='', d_eff=d_eff, d_nent=d_rare, l_dset=[dset])
    d_mod        = model.get_model()
    d_val, d_var = model.get_cons()

    return d_mod, d_val, d_var
#-----------------------------
def fit_dset(trig, dset):
    '''
    Fit one trigger/dataset combination and save the fitted parameters as JSON.

    Parameters
    ----------
    trig: Trigger name, one of 'MTOS', 'ETOS', 'GTIS'.
    dset: Dataset name, e.g. 'r1', '2018', 'all'.

    Returns
    -------
    Dictionary of fitted parameters that was written to
    `{data.out_dir}/{dset}_{trig}.json`, or None if that file already
    existed and the fit was skipped.

    Raises
    ------
    ValueError: If `trig` is not a supported trigger. Without this check an
                unknown trigger would silently be fitted with the `ee` TOS model.
    '''
    if trig not in ('MTOS', 'ETOS', 'GTIS'):
        raise ValueError(f'Unsupported trigger: {trig}')

    pars_path = f'{data.out_dir}/{dset}_{trig}.json'
    if os.path.isfile(pars_path):
        log.info(f'Parameters already found for {trig}-{dset}, skipping')
        return None

    # GTIS uses the TIS model key; MTOS and ETOS use the TOS one
    trg                 = 'TIS' if trig == 'GTIS' else 'TOS'
    key                 = f'{dset}_{trg}'
    d_mod, d_val, d_var = get_model(key)
    mod_mm, mod_ee      = d_mod[key]

    # Only MTOS uses the `mm` model; ETOS and GTIS use the `ee` one
    mod          = mod_mm if trig == 'MTOS' else mod_ee
    obj          = normalizer(dset=dset, trig=trig, model=mod, d_val=d_val, d_var=d_var)
    obj.out_dir  = data.out_dir
    res          = obj.get_fit_result()

    d_par = zut.res_to_dict(res, frozen=True)
    utnr.dump_json(d_par, pars_path)
    log.info(f'Saving to: {pars_path}')

    return d_par
#-----------------------------
def main():
    '''
    Run `fit_dset` for every combination of the configured triggers
    (`data.l_trig`) and datasets (`data.l_dset`).
    '''
    for trig in data.l_trig:
        for dset in data.l_dset:
            log.info(f'Running {trig}-{dset}')
            fit_dset(trig, dset)
#-----------------------------
if __name__ == '__main__':
    # Order matters: get_args sets data.version, prepare_output uses it to
    # create data.out_dir, and main writes its results into data.out_dir.
    for step in (get_args, prepare_output, main):
        step()

