#!/usr/bin/env python3

import warnings
# Silence every UserWarning and DeprecationWarning for the whole process.
# These filters are installed before the third-party imports below, so warnings
# emitted while those modules are being imported are filtered as well.
# NOTE(review): this is global and may hide useful diagnostics from any library.
warnings.filterwarnings('ignore', category=UserWarning)
warnings.filterwarnings('ignore', category=DeprecationWarning)

import os
import re
import ROOT
import zfit
import pprint
import argparse
import utils_noroot   as utnr
import zutils.utils   as zut
import read_selection as rs

from importlib.resources import files
from normalizer          import normalizer
from np_reader           import np_reader as np_rdr
from rk                  import utilities as rkut
from rk_model            import rk_model
from log_store           import log_store
from builder             import builder   as cmb_bld

# Module-level logger obtained from log_store; used for progress and error messages below.
log = log_store.add_logger(name='rk_extractor:cmb_prec_nom')
#-----------------------------
class data:
    '''
    Module-wide configuration namespace (used as a class, never instantiated).

    The `l_*` lists hold the allowed values for the command-line options.
    Note that `l_shr_par` is overwritten by get_args() with the user's selection.
    The remaining attributes start as None and are filled in from the command
    line (get_args) or while preparing the output directory (prepare_output).
    '''
    # Allowed choices for --dataset, --trigger and --shr_par
    l_dset    = ['r1', 'r2p1', '2017', '2018', 'all']
    l_trig    = ['MTOS', 'ETOS', 'GTIS']
    l_shr_par = ['mu', 'lm']

    # Selection of what to fit
    blind = dset = trig = None

    # Output bookkeeping and model options
    version = out_dir = cmb_rep = None
#-----------------------------
def prepare_output():
    '''
    Create the versioned output directory for the sideband fits.

    The directory is `sb_fits/<version>` inside the `extractor_data` package
    resources. It is created if missing, and its path is stored in `data.out_dir`.
    '''
    sub_dir    = f'sb_fits/{data.version}'
    target_dir = files('extractor_data').joinpath(sub_dir)

    os.makedirs(target_dir, exist_ok=True)

    # Only publish the path once the directory is known to exist
    data.out_dir = target_dir
#-----------------------------
def str_to_bool(value):
    '''
    Convert a command-line string such as "True"/"false"/"1"/"0" to a bool.

    argparse's `type=bool` is not usable for this: it calls `bool(str)`, which
    is True for every non-empty string, including "False".

    Raises argparse.ArgumentTypeError for unrecognised values, so argparse
    reports a normal usage error.
    '''
    if isinstance(value, bool):
        return value

    norm = str(value).strip().lower()
    if norm in ('true', '1', 'yes', 'y'):
        return True

    if norm in ('false', '0', 'no', 'n'):
        return False

    raise argparse.ArgumentTypeError(f'Expected a boolean value, got: {value}')


def get_args(argv=None):
    '''
    Parse the command line and store the options on the `data` namespace.

    Parameters
    ----------
    argv : list of str, optional
        Arguments to parse; when None (default) argparse reads sys.argv.

    Side effects
    ------------
    Sets data.version, data.dset, data.blind, data.trig and data.cmb_rep.
    Replaces data.l_shr_par with the parameters selected via --shr_par.
    '''
    parser = argparse.ArgumentParser(description='Used to calculate normalizations for combinatorial and PRec from data sidebands')
    parser.add_argument('-v', '--version' ,  type=str, help='Version of output, used to name directories', required=True)
    parser.add_argument('-t', '--trigger' ,  type=str, help='Trigger', choices=data.l_trig)
    parser.add_argument('-d', '--dataset' ,  type=str, help='Dataset', choices=data.l_dset)
    # str_to_bool instead of bool: bool('False') would be True
    parser.add_argument('-B', '--blind'   ,  type=str_to_bool, help='Will blind SR or not', default=False, choices=[True, False])
    parser.add_argument('-s', '--shr_par' , nargs='+', help='Parameters shared between bins', choices=data.l_shr_par, default=[])
    parser.add_argument('-r', '--cmb_rep' ,  type=int, help='Reparametrize combinatorial yield with linear dependence across bins', choices=[0, 1], default=0)
    args = parser.parse_args(argv)

    data.version   = args.version
    data.dset      = args.dataset
    data.blind     = args.blind
    data.trig      = args.trigger
    data.cmb_rep   = args.cmb_rep
    # The choices list above is evaluated first, then replaced by the user's selection
    data.l_shr_par = args.shr_par
#-----------------------------
def get_channel(trig=None):
    '''
    Map a trigger name to its dilepton channel.

    Parameters
    ----------
    trig : str, optional
        Trigger name. When None (default) `data.trig`, set by get_args(), is used.

    Returns
    -------
    str
        'mm' for MTOS, 'ee' for ETOS and GTIS.

    Raises
    ------
    ValueError
        If the trigger is not one of the values above.
    '''
    if trig is None:
        trig = data.trig

    if trig == 'MTOS':
        return 'mm'

    if trig in ('ETOS', 'GTIS'):
        return 'ee'

    # A bare `raise` here (no active exception) would only produce
    # "RuntimeError: No active exception to reraise"; raise a meaningful error instead.
    msg = f'Wrong trigger value: {trig}'
    log.error(msg)
    raise ValueError(msg)
#-----------------------------
def get_nentries(df):
    '''
    Sum the per-trigger yields of `df` into total entry counts.

    Parameters
    ----------
    df : object with `sign_MTOS` and `sign_ETOS` attributes supporting `.sum()`
        (presumably the ryields DataFrame from np_reader; verify).

    Returns
    -------
    dict
        {'all' : (sum of sign_MTOS, sum of sign_ETOS)}, i.e. (muon, electron).
    '''
    totals = tuple(getattr(df, f'sign_{trig}').sum() for trig in ('MTOS', 'ETOS'))

    return {'all' : totals}
#-----------------------------
def get_model():
    '''
    Build the fit model used for the sideband fits.

    Steps:
    1. Read the yields with np_reader (sys='v65', sta='v81', yld='v24').
    2. Turn them into entry counts with get_nentries().
    3. Pick the channel from the configured trigger.
    4. Build an rk_model with preffix 'sb_fits' on a 'mass mm' zfit space
       limited to (2600, 3300).

    Returns
    -------
    Whatever rk_model.get_model() returns. It is passed to normalizer as
    `d_model` (presumably a dict of models; verify).
    '''
    reader   = np_rdr(sys='v65', sta='v81', yld='v24')
    yields   = reader.get_ryields()
    nentries = get_nentries(yields)

    channel    = get_channel()
    mass_space = zfit.Space('mass mm', limits=(2600, 3300))

    builder_obj = rk_model(preffix='sb_fits', d_nent=nentries, channel=channel, obs_mm_sp=mass_space)

    return builder_obj.get_model()
#-----------------------------
def main():
    '''
    Fit the data sidebands for the configured trigger and dataset, then save the parameters.

    Output goes to `<data.out_dir>/<dset>_<trig>.json`. If that file already
    exists, the fit is skipped. Parameters whose names start with 'nsg_' are
    dropped before the result is written.
    '''
    out_json = f'{data.out_dir}/{data.dset}_{data.trig}.json'
    if os.path.isfile(out_json):
        log.info(f'Parameters already found for {data.trig}-{data.dset}, skipping')
        return

    fitter = normalizer(
        dset    = data.dset,
        trig    = data.trig,
        d_model = get_model(),
        d_val   = {},
        d_var   = {},
        blind   = data.blind,
    )
    fitter.out_dir = data.out_dir
    fit_result     = fitter.get_fit_result()

    all_pars  = zut.res_to_dict(fit_result, frozen=True)
    kept_pars = {name : value for name, value in all_pars.items() if not name.startswith('nsg_')}

    utnr.dump_json(kept_pars, out_json)
    log.info(f'Saving to: {out_json}')
#-----------------------------
if __name__ == '__main__':
    # Parse the command line, create the output directory, then run the fit.
    for _step in (get_args, prepare_output, main):
        _step()

