#!/usr/bin/env python3

import ROOT

import os
import re
import numpy

from log_store  import log_store
from rk_model   import rk_model  as model
from np_reader  import np_reader as np_rdr
from extractor  import extractor as ext
from cmb_ck     import combiner  as cmb_ck

import rk.utilities        as rkut
import pandas              as pnd
import utils_noroot        as utnr

import argparse

# Module-level logger, obtained from log_store under the name 'rk_extractor:fit_toy'
log = log_store.add_logger(name='rk_extractor:fit_toy')
#--------------------------------
class data:
    '''
    Holder of the settings shared by the functions in this file, accessed through class attributes
    '''
    # Settings that get_args overwrites with the command line values
    rk        = 1.0
    bdt_bin   = 5
    l_dset    = ['all_TOS']
    rseed     = None
    l_fix_var = None
    mod_var   = None

    # Directory under which the fit results are written
    out_dir   = 'results'
#--------------------------------
def add_ne(d_pos, d_pre):
    '''
    Add nsg_ee_* entries to d_pos, one set for each of its nsg_mm_* entries

    For every key of d_pos of the form nsg_mm_<suffix>, where <suffix> contains _TOS_,
    the key nsg_ee_<suffix> is filled with the value returned by get_ne. The entry for the
    TIS suffix (same suffix with _TOS_ replaced by _TIS_) is added only if ck_<TIS suffix>
    is a key of d_pos; otherwise a warning is logged and that entry is skipped.

    Parameters
    ----------
    d_pos: Dictionary that is scanned for the nsg_mm_* keys, it is updated in place
    d_pre: Dictionary that is only forwarded to get_ne

    Returns
    ----------
    The same d_pos object that was passed, after the update
    '''
    d_new = {}
    for name in d_pos:
        found = re.match('nsg_mm_(.*_TOS_.*)', name)
        if found is None:
            continue

        tos = found.group(1)
        tis = tos.replace('_TOS_', '_TIS_')

        d_new[f'nsg_ee_{tos}'] = get_ne(tos, d_pos, d_pre)
        if f'ck_{tis}' not in d_pos:
            log.warning(f'TIS ck not found, skiping electron TIS yield')
            continue

        d_new[f'nsg_ee_{tis}'] = get_ne(tis, d_pos, d_pre)

    # New entries are collected separately so that d_pos is not modified while iterating over it
    d_pos.update(d_new)

    return d_pos
#--------------------------------
def fit(rseed=None, l_fix_var=None, mod_var=None):
    '''
    Run one fit and dump the result as a pickle file and a JSON file under data.out_dir

    Parameters
    ----------
    rseed     : Seed passed to the data generation, also used to name the plots directory and the output files.
                NOTE(review): it is formatted with integer format specs below, so the default None
                would raise -- confirm that callers always pass an integer
    l_fix_var : Assigned to the extractor's fix attribute; only its length enters the output file names
    mod_var   : Assigned to the model's kind attribute; when None the kind is 'nom' and the data come from mod.get_data

    Returns
    ----------
    None, the outputs are the files written at the end
    '''
    log.info(f'Seed: {rseed:04}')
    log.info(f'Variable fixed : {l_fix_var}')
    log.info(f'Model variation: {mod_var}')

    # Inputs read through np_reader: covariances, efficiencies, rjpsi and yields
    rdr          = np_rdr(sys='v65', sta='v63', yld='v24')
    cv_sys       = rdr.get_cov(kind='sys')
    cv_sta       = rdr.get_cov(kind='sta')
    d_eff        = rdr.get_eff()
    d_rjpsi      = rdr.get_rjpsi()
    d_byld       = rdr.get_byields()
    d_nent       = rkut.average_byields(d_byld, l_exclude=['TIS'])
    d_rare_yld   = rkut.reso_to_rare(d_nent, kind='jpsi')

    # bdt_bin and kind are set on the model before get_model is called
    mod          = model(rk=data.rk, preffix='toys_fit', d_eff=d_eff, d_nent=d_rare_yld, l_dset=data.l_dset)
    mod.bdt_bin  = data.bdt_bin
    mod.kind     = 'nom' if mod_var is None else mod_var
    d_mod        = mod.get_model() 
    d_val, d_var = mod.get_cons() 
    d_pre        = mod.get_prefit_pars(d_var=d_var, ck_cov=cv_sys+cv_sta)

    # Without a model variation the data come from the model itself, otherwise from get_data
    # NOTE(review): get_data is neither defined nor imported in the visible part of this file -- confirm where it comes from
    if mod_var is None:
        d_dat    = mod.get_data(rseed=rseed)
    else:
        d_dat    = get_data(d_eff=d_eff, d_nent=d_rare_yld, rseed=rseed)

    # For these two dataset lists rjpsi, efficiencies and covariance are replaced by the combiner output,
    # for any other list the covariance is the sum of the sys and sta matrices
    if data.l_dset == ['all_TOS'] or data.l_dset == ['all_TOS', 'all_TIS']:
        cmb                 = cmb_ck(rk=data.rk, eff=d_eff, yld=d_rare_yld)
        cmb.out_dir         = 'plots/combination'
        t_comb              = cmb.get_combination()
        d_rjpsi, d_eff, cov = t_comb
    else:
        cov = cv_sys + cv_sta

    # When fitting a single dataset, only one diagonal element of the covariance is kept, as a 1x1 array
    if   data.l_dset == ['all_TOS']:
        cov = numpy.array([[cov[0][0]]])
    elif data.l_dset == ['all_TIS']:
        cov = numpy.array([[cov[1][1]]])

    # Configure the extractor and run the fit
    obj          = ext(dset=data.l_dset, drop_correlations=False)
    obj.plt_dir  = f'plots/fits_{rseed:03}'
    obj.rjpsi    = d_rjpsi
    obj.eff      = d_eff
    obj.data     = d_dat
    obj.model    = d_mod 
    obj.fix      = l_fix_var
    obj.cov      = cov 
    obj.const    = d_val, d_var
    result       = obj.get_fit_result()

    # The dictionary of post-fit values is built after hesse() and before freeze() is called on the result
    log.info(f'Calculating errors')
    result.hesse()
    d_pos = rkut.result_to_dict(result) 
    d_pos = add_ne(d_pos, d_pre)
    result.freeze()

    # Pre-fit and post-fit dictionaries are saved together in the JSON file
    d_inf = {'pre' : d_pre, 'pos' : d_pos} 

    suffix = get_suffix(rseed, l_fix_var, mod_var)

    utnr.dump_pickle(result, f'{data.out_dir}/result_pkl/result_{suffix}.pkl')
    utnr.dump_json(d_inf, f'{data.out_dir}/result_jsn/result_{suffix}.json')
#--------------------------------
def get_suffix(rseed, l_fix_var, mod_var):
    '''
    Build the identifier used in the names of the output files

    Parameters
    ----------
    rseed     : Integer seed, zero-padded to four digits
    l_fix_var : Container of fixed variables or None; only its size is used, zero-padded to four digits
    mod_var   : Name of the model variation, None is written as nominal

    Returns
    ----------
    String with the form SEED_NFIX_MODEL, e.g. 0003_0002_nominal
    '''
    seed_tag  = f'{rseed:04}'
    nfix      = 0 if l_fix_var is None else len(l_fix_var)
    model_tag = 'nominal' if mod_var is None else mod_var

    return '_'.join([seed_tag, f'{nfix:04}', f'{model_tag}'])
#--------------------------------
def get_args():
    '''
    Parse the command line and copy the parsed values into the attributes of the data class

    The defaults of --rk, --bdt_bin and --datasets are the values held by data at the time of the call,
    --rseed is mandatory. Nothing is returned.
    '''
    parser = argparse.ArgumentParser(description='Will run a single fit to toy data')
    add    = parser.add_argument

    add('-r', '--rk', type=float, default=data.rk, help='Value of rk')
    add('-s', '--rseed', type=int, required=True, help='Random seed')
    add('-b', '--bdt_bin', type=int, default=data.bdt_bin, help='BDT bin')
    add('-f', '--fix_vars', nargs='+', help='Variables to fix if doing systematic studies')
    add('-d', '--datasets', nargs='+', default=data.l_dset, help='List of datasets to fit')
    add('-m', '--alt_model', type=str, help='Alternative model for systematic studies')

    parsed = parser.parse_args()

    # Name of the parsed option -> name of the attribute of data that receives it
    d_target = {
        'rk'        : 'rk',
        'rseed'     : 'rseed',
        'datasets'  : 'l_dset',
        'fix_vars'  : 'l_fix_var',
        'bdt_bin'   : 'bdt_bin',
        'alt_model' : 'mod_var',
    }

    for opt_name, attr_name in d_target.items():
        setattr(data, attr_name, getattr(parsed, opt_name))
#--------------------------------
def main():
    '''
    Script entry point: read the command line, make sure the output directory exists and run the fit
    '''
    get_args()

    out_dir = data.out_dir
    os.makedirs(out_dir, exist_ok=True)

    fit(
        rseed     = data.rseed,
        l_fix_var = data.l_fix_var,
        mod_var   = data.mod_var,
    )
#--------------------------------
# Call main only when this file is executed as a script
if __name__ == '__main__':
    main()

