#!/usr/bin/env python3

import ROOT

import os
import re
import zfit
import numpy
import pprint

from log_store  import log_store
from rk_model   import rk_model  as model
from np_reader  import np_reader as np_rdr
from extractor  import extractor as ext
from cmb_ck     import combiner  as cmb_ck

import zutils.utils        as zut
import rk.utilities        as rkut
import pandas              as pnd
import utils_noroot        as utnr

import argparse

# Module-level logger used by every function in this script.
log = log_store.add_logger(name='rk_extractor:fit_toy')
#--------------------------------
class data:
    '''
    Run-wide settings shared by the functions of this script.

    get_args() overwrites rk, rseed, l_dset, l_fix_var and mod_var with the
    command-line values; out_dir is not configurable from the command line.
    '''
    rseed     = None         # toy random seed, set from -s/--rseed
    l_fix_var = None         # variables to fix, set from -f/--fix_vars
    mod_var   = None         # alternative fit model kind, set from -m/--alt_model
    l_dset    = ['all_TOS']  # datasets, set from -d/--datasets
    rk        = 1.0          # RK value given to the generating and fitting models
    out_dir   = 'results'    # top-level directory for all outputs
#--------------------------------
def add_ne(d_pos, d_pre):
    '''
    Add electron signal yields to the post-fit dictionary.

    For every key of d_pos matching 'nsg_mm_<suffix>' with '_TOS_' in the
    suffix, the entry 'nsg_ee_<suffix>' is computed with get_ne. The matching
    TIS entry 'nsg_ee_<suffix with _TOS_ -> _TIS_>' is added only when
    'ck_<TIS suffix>' is a key of d_pos; otherwise a warning is logged.

    Parameters
    ----------
    d_pos : dict
        Post-fit values keyed by parameter name; updated in place.
    d_pre : dict
        Pre-fit values, forwarded unchanged to get_ne.

    Returns
    -------
    dict
        The same d_pos object, with the new nsg_ee_* entries added.
    '''
    # NOTE(review): get_ne is neither defined nor imported in this file, so any
    # matching key raises NameError unless get_ne is provided elsewhere -- confirm.
    regex     = re.compile(r'nsg_mm_(.*_TOS_.*)')
    d_pos_ext = {}

    # New entries are collected separately because d_pos must not change size
    # while it is being iterated.
    for var_name in d_pos:
        mtch = regex.match(var_name)
        if not mtch:
            continue

        suffix_tos = mtch.group(1)
        suffix_tis = suffix_tos.replace('_TOS_', '_TIS_')

        d_pos_ext[f'nsg_ee_{suffix_tos}'] = get_ne(suffix_tos, d_pos, d_pre)

        if f'ck_{suffix_tis}' not in d_pos:
            log.warning(f'TIS ck not found for {suffix_tis}, skipping electron TIS yield')
            continue

        d_pos_ext[f'nsg_ee_{suffix_tis}'] = get_ne(suffix_tis, d_pos, d_pre)

    d_pos.update(d_pos_ext)

    return d_pos
#--------------------------------
def get_data(d_nent=None, rseed=None, obs_mm_sp=None):
    '''
    Build the nominal generating model (prefix 'toys_gen') and draw toy data.

    The generating model is also dumped to text through print_model with
    name='gen_pdf'.

    Parameters
    ----------
    d_nent    : yields dictionary, passed to the model as d_nent
    rseed     : random seed, forwarded to the model's get_data
    obs_mm_sp : observable space, passed to the model as obs_mm_sp

    Returns
    -------
    Whatever the model's get_data(rseed=...) returns.
    '''
    gen_mod      = model(rk=data.rk, preffix='toys_gen', d_nent=d_nent, obs_mm_sp=obs_mm_sp)
    gen_mod.kind = 'nom'

    d_toy = gen_mod.get_data(rseed=rseed)
    print_model(gen_mod.get_model(), name='gen_pdf')

    return d_toy
#--------------------------------
def print_model(d_mod, name=None):
    '''
    Dump the three pdfs stored under d_mod['all'] to text files.

    d_mod['all'] is unpacked as ((bm_mm pdf, qm_mm pdf), ee pdf). Each pdf is
    written with zut.print_pdf to {data.out_dir}/model/<tag>_{name}.txt, with
    tags bm_mm, qm_mm and bm_ee respectively.
    '''
    l_mm_pdf, pdf_ee     = d_mod['all']
    pdf_bm_mm, pdf_qm_mm = l_mm_pdf

    model_dir = f'{data.out_dir}/model'
    l_tag_pdf = [('bm_mm', pdf_bm_mm), ('qm_mm', pdf_qm_mm), ('bm_ee', pdf_ee)]
    for tag, pdf in l_tag_pdf:
        zut.print_pdf(pdf, txt_path=f'{model_dir}/{tag}_{name}.txt')
#--------------------------------
def get_nentries(df):
    '''
    Sum the sign_MTOS and sign_ETOS columns of df.

    Returns
    -------
    dict
        {'all' : (sum of sign_MTOS, sum of sign_ETOS)}
    '''
    totals = tuple(getattr(df, column).sum() for column in ('sign_MTOS', 'sign_ETOS'))

    return {'all' : totals}
#--------------------------------
def get_ck(df):
    '''
    Compute ck as sum(sign_ETOS) / sum(sign_MTOS).

    Parameters
    ----------
    df : table whose sign_MTOS and sign_ETOS columns are accessed as attributes

    Returns
    -------
    Ratio of the summed sign_ETOS column to the summed sign_MTOS column.

    Raises
    ------
    ValueError
        If the sign_MTOS sum is zero. Without this check the ratio would be
        inf/nan (or ZeroDivisionError, depending on the number type) and be
        passed on silently.
    '''
    nent_mm = df.sign_MTOS.sum()
    nent_ee = df.sign_ETOS.sum()

    if nent_mm == 0:
        raise ValueError(f'Cannot compute ck: sum of sign_MTOS is zero (sum of sign_ETOS: {nent_ee})')

    return nent_ee / nent_mm
#--------------------------------
def fit(rseed=None, l_fix_var=None, mod_var=None):
    '''
    Fit one toy dataset and write the fit result to disk.

    Steps:
    - Read rare-mode yields and ck with np_reader.
    - Build the fit model (prefix 'toys_fit', kind 'nom' or mod_var).
    - Draw toy data. With mod_var None the data come from the fit model
      itself; otherwise they come from a separate nominal model via get_data.
    - Fit with extractor, compute errors with hesse and add electron yields
      with add_ne.
    - Write the outputs:
        {data.out_dir}/result_pkl/result_{suffix}.pkl  -- pickled fit result
        {data.out_dir}/result_jsn/result_{suffix}.json -- {'pre': ..., 'pos': ...}

    Parameters
    ----------
    rseed     : int, toy random seed; formatted with :04, so None raises TypeError
    l_fix_var : list of variable names to fix in the fit, or None
    mod_var   : alternative model kind for the fit, or None for the nominal model
    '''
    log.info(f'Seed: {rseed:04}')
    log.info(f'Variable fixed : {l_fix_var}')
    log.info(f'Model variation: {mod_var}')

    log.warning(f'Using an error of ck of 1% for toys')
    # NOTE(review): 0.01 ** 2 is an absolute variance (sigma = 0.01). That is a
    # 1% relative error only if ck ~ 1; if a relative 1% is intended this would
    # be (0.01 * ck) ** 2 -- confirm.
    ck_var       = 0.01 ** 2

    # Input versions passed to np_reader (sys, sta, yld)
    rdr          = np_rdr(sys='v65', sta='v81', yld='v24')
    df_ryld      = rdr.get_ryields()
    d_rare_yld   = get_nentries(df_ryld)
    ck           = get_ck(df_ryld)

    # Observable space shared by the fit model and, if used, the generating model
    obs_mm_sp    = zfit.Space('mass mm', limits=(2600, 3300))
    mod          = model(rk=data.rk, preffix='toys_fit', d_nent=d_rare_yld, obs_mm_sp=obs_mm_sp)
    mod.kind     = 'nom' if mod_var is None else mod_var
    d_mod        = mod.get_model() 
    d_val, d_var = mod.get_cons() 
    # Pre-fit parameters; saved under 'pre' in the JSON and passed to add_ne
    d_pre        = mod.get_prefit_pars(d_var=d_var, ck_var=ck_var)

    if mod_var is None:
        # Nominal case: the fit model is also the generator, hence the 'gen_pdf' name
        print_model(d_mod, name='gen_pdf')
        d_dat    = mod.get_data(rseed=rseed)
    else:
        # Alternative fit model: toys are generated from a separate nominal model
        d_dat    = get_data(d_nent=d_rare_yld, rseed=rseed, obs_mm_sp=obs_mm_sp)

    suffix       = get_suffix(rseed, l_fix_var, mod_var)

    obj          = ext(drop_correlations=False)
    obj.plt_dir  = f'{data.out_dir}/model/fits_{suffix}'
    obj.ck       = ck 
    obj.cov      = ck_var 
    obj.data     = d_dat
    obj.model    = d_mod 
    obj.fix      = l_fix_var
    obj.const    = d_val, d_var
    result       = obj.get_fit_result()

    log.info(f'Calculating errors')
    result.hesse()
    d_pos = rkut.result_to_dict(result) 
    d_pos = add_ne(d_pos, d_pre)
    # Presumably required so the result can be pickled below; confirm against zfit FitResult.freeze
    result.freeze()

    d_inf = {'pre' : d_pre, 'pos' : d_pos} 


    # NOTE(review): main() creates data.out_dir; confirm that dump_pickle and
    # dump_json create the result_pkl/ and result_jsn/ subdirectories if missing.
    utnr.dump_pickle(result, f'{data.out_dir}/result_pkl/result_{suffix}.pkl')
    utnr.dump_json(d_inf, f'{data.out_dir}/result_jsn/result_{suffix}.json')
#--------------------------------
def get_suffix(rseed, l_fix_var, mod_var):
    '''
    Build the tag used in the output file names of a toy fit.

    The format is <seed:04>_<number of fixed variables:04>_<model>_<rk>:
    - <model> is 'nominal' when mod_var is None, otherwise mod_var.
    - <rk> is data.rk with two decimals and '.' replaced by 'p', e.g.
      1.0 -> '1p00'.
    '''
    seed_tag = format(rseed, '04')

    nfix     = 0 if l_fix_var is None else len(l_fix_var)
    fix_tag  = format(nfix, '04')

    mvar_tag = 'nominal' if mod_var is None else mod_var
    rk_tag   = format(data.rk, '0.2f').replace('.', 'p')

    return f'{seed_tag}_{fix_tag}_{mvar_tag}_{rk_tag}'
#--------------------------------
def get_args():
    '''
    Parse the command line and store the options as attributes of data.

    Options: -r/--rk, -s/--rseed (required), -f/--fix_vars, -d/--datasets and
    -m/--alt_model. The defaults for --rk and --datasets come from data.
    '''
    parser = argparse.ArgumentParser(description='Will run a single fit to toy data')
    add    = parser.add_argument
    add('-r', '--rk'       , type=float, help='Value of rk', default=data.rk)
    add('-s', '--rseed'    , type=int  , help='Random seed', required=True)
    add('-f', '--fix_vars' , nargs='+' , help='Variables to fix if doing systematic studies')
    add('-d', '--datasets' , nargs='+' , help='List of datasets to fit', default=data.l_dset)
    add('-m', '--alt_model', type=str  , help='Alternative model for systematic studies')
    args = parser.parse_args()

    # NOTE(review): data.l_dset is stored here but never read elsewhere in
    # this file -- confirm whether --datasets should reach fit().
    d_setting = {
        'rk'        : args.rk,
        'rseed'     : args.rseed,
        'l_dset'    : args.datasets,
        'l_fix_var' : args.fix_vars,
        'mod_var'   : args.alt_model,
    }
    for attr_name, value in d_setting.items():
        setattr(data, attr_name, value)
#--------------------------------
def main():
    '''
    Entry point: parse the arguments, create the output directories and run one toy fit.
    '''
    get_args()
    os.makedirs(data.out_dir, exist_ok=True)
    # fit() and print_model() write into these subdirectories of data.out_dir.
    # Create them up front; this is harmless if the writers also create them.
    for sub_dir in ['model', 'result_pkl', 'result_jsn']:
        os.makedirs(os.path.join(data.out_dir, sub_dir), exist_ok=True)

    fit(rseed = data.rseed, l_fix_var = data.l_fix_var, mod_var = data.mod_var)
#--------------------------------
# Run the toy fit only when executed as a script, not when imported.
if __name__ == '__main__':
    main()

