#!/usr/bin/env python3

import os
import ROOT
import zfit
import glob
import numpy
import argparse
import pandas              as pnd
import logzero
import tarfile
import rk.utilities        as rkut
import utils_noroot        as utnr

from rkex_model import model
from np_reader  import np_reader as np_rdr
from mc_reader  import mc_reader as mc_rdr
from cs_reader  import cs_reader as cs_rdr
from extractor  import extractor as ext
from logzero    import logger    as log
from zutils     import utils     as zut

#--------------------------------
class data:
    # Module-wide run configuration shared by the functions of this script.
    # get_args() fills log_lvl and l_dset, initialize() fills l_seed;
    # fit() reads l_dset and main() reads l_seed and out_dir.
    log_lvl = None              # logging level chosen on the command line
    l_dset  = None              # dataset names handed to the extractor
    l_seed  = None              # integer seeds collected from *.sd files
    out_dir = 'results'         # directory where results and tarballs are written
#--------------------------------
def _add_ne(d_pos):
    return d_pos
#--------------------------------
def fit(rseed):
    '''
    Build the RK extraction model, generate a toy dataset for one seed and fit it.

    Parameters
    ----------
    rseed : int
        Seed passed to the model to generate the toy data; also used to name the plot directory.

    Returns
    -------
    tuple
        (result, d_inf) where `result` is the fit result returned by the extractor,
        after `hesse()` and `freeze()` have been called on it, and `d_inf` is
        {'pre' : prefit parameters from the model, 'pos' : postfit parameters
        obtained through `rkut.result_to_dict`}.
    '''
    log.info(f'Seed: {rseed:04}')

    preffix = 'toys'

    # 'mu' and 'sg' parameters, read once from simulation (real_data=False) and once from data (real_data=True)
    rdr_mc          = mc_rdr(version='v4', real_data=False)
    rdr_mc.cache    = True
    rdr_mc.cache_dir= 'v4_mcrdr'
    d_mcmu          = rdr_mc.get_parameter(name='mu')
    d_mcsg          = rdr_mc.get_parameter(name='sg')

    rdr_dt          = mc_rdr(version='v4', real_data=True)
    rdr_dt.cache    = True
    rdr_dt.cache_dir= 'v4_dtrdr'
    d_dtmu          = rdr_dt.get_parameter(name='mu')
    d_dtsg          = rdr_dt.get_parameter(name='sg')

    # Covariances, efficiencies, r_jpsi and yields from the nuisance-parameter reader
    rdr_np          = np_rdr(sys='v65', sta='v63', yld='v24')
    rdr_np.cache    = True
    rdr_np.cache_dir= 'v65_v63_v24'
    cv_sys          = rdr_np.get_cov(kind='sys')
    cv_sta          = rdr_np.get_cov(kind='sta')
    d_eff           = rdr_np.get_eff()
    d_rjpsi         = rdr_np.get_rjpsi()
    d_byld          = rdr_np.get_byields()
    d_nent          = rkut.average_byields(d_byld, l_exclude=['TIS'])
    # Yields converted from the jpsi mode; these are what the model receives as d_nent
    d_rare_yld      = rkut.reso_to_rare(d_nent, kind='jpsi')

    mod             = model(preffix=preffix, d_eff=d_eff,  d_mcmu=d_mcmu, d_mcsg=d_mcsg, d_nent=d_rare_yld, d_dtmu=d_dtmu, d_dtsg=d_dtsg)
    d_mod           = mod.get_model()
    d_dat           = mod.get_data(rseed=rseed)
    d_pre           = mod.get_prefit_pars()

    # Constraints, as the (d_val, d_var) pair returned by the constraint reader
    rdr_cs          = cs_rdr(version='v4', preffix=preffix)
    rdr_cs.cache    = True
    rdr_cs.cache_dir= 'v4_csrdr'
    d_val, d_var    = rdr_cs.get_constraints()

    obj             = ext(dset=data.l_dset, drop_correlations=False)
    # NOTE(review): plot directories use 3-digit seed padding while main() names result files with 4 digits
    obj.plt_dir     = f'plots/fits_{rseed:03}'
    obj.rjpsi       = d_rjpsi
    obj.eff         = d_eff
    # Sum of the 'sys' and 'sta' covariances
    obj.cov         = cv_sys + cv_sta
    obj.data        = d_dat
    obj.model       = d_mod
    obj.const       = d_val, d_var
    result          = obj.get_fit_result()

    log.info('Calculating errors')
    result.hesse()
    d_pos = rkut.result_to_dict(result)
    # _add_ne is a module-level helper; there is no `self` in this function
    d_pos = _add_ne(d_pos)
    result.freeze()

    return result, {'pre' : d_pre, 'pos' : d_pos}
#--------------------------------
def initialize():
    '''
    Prepare the run: apply the log level, collect the seeds and create the output directory.
    '''
    log.setLevel(data.log_lvl)

    seeds       = get_seeds()
    data.l_seed = seeds

    os.makedirs(data.out_dir, exist_ok=True)
#--------------------------------
def cleanup_env():
    '''
    Delete every entry of zfit's `Parameter._existing_params` registry.

    main() calls this after each seed, presumably so the next model can
    re-create parameters with the same names — confirm against the zfit version in use.
    '''
    registry = zfit.Parameter._existing_params

    # Snapshot the names first: deleting while iterating the mapping itself is not allowed.
    names = [*registry]
    for name in names:
        del registry[name]
#--------------------------------
def get_args(argv=None):
    '''
    Parse the command line and store the options in the `data` class.

    Parameters
    ----------
    argv : list of str, optional
        Arguments to parse. When None (the default) argparse falls back to
        sys.argv[1:], i.e. the original command-line behaviour.

    Side effects
    ------------
    Sets `data.log_lvl` and `data.l_dset`.
    '''
    parser = argparse.ArgumentParser(description='Used to run toy fits on model used to extract RK')
    parser.add_argument('-l', '--level' , type =int, help='Logging level', choices=[logzero.DEBUG, logzero.INFO, logzero.WARNING], default=logzero.INFO)
    # Optional: when omitted, data.l_dset stays None and is passed as such to the extractor
    parser.add_argument('-d', '--dset'  , nargs='+', help='Datasets to use')
    args = parser.parse_args(argv)

    data.log_lvl = args.level
    data.l_dset  = args.dset
#--------------------------------
def main():
    '''
    Run one toy fit per seed, save every result and pack the outputs into tarballs.

    For each seed the fit result is pickled to <out_dir>/result_pkl/result_XXXX.pkl
    and the pre/post-fit parameters are written to <out_dir>/result_jsn/result_XXXX.json.
    Both directories are then archived as <out_dir>/result_pkl.tar.gz and
    <out_dir>/result_jsn.tar.gz.
    '''
    get_args()
    initialize()

    pkl_dir = f'{data.out_dir}/result_pkl'
    jsn_dir = f'{data.out_dir}/result_jsn'
    # initialize() only creates out_dir itself; make sure the per-format
    # subdirectories exist before writing into them and before archiving them.
    os.makedirs(pkl_dir, exist_ok=True)
    os.makedirs(jsn_dir, exist_ok=True)

    for rseed in data.l_seed:
        res, d_inf = fit(rseed)
        print(res)
        utnr.dump_pickle(res, f'{pkl_dir}/result_{rseed:04}.pkl')
        utnr.dump_json(d_inf, f'{jsn_dir}/result_{rseed:04}.json')

        # Empty zfit's parameter registry before the next seed builds a new model
        cleanup_env()

    with tarfile.open(f'{data.out_dir}/result_pkl.tar.gz', 'w:gz') as tar:
        tar.add(pkl_dir, arcname='result_pkl')

    with tarfile.open(f'{data.out_dir}/result_jsn.tar.gz', 'w:gz') as tar:
        tar.add(jsn_dir, arcname='result_jsn')
#--------------------------------
def get_file_seeds(seed_file):
    '''
    Read the seeds listed in a seed file, one per line.

    Parameters
    ----------
    seed_file : str
        Path to a text file with one seed per line.

    Returns
    -------
    list of str
        The non-empty lines with surrounding whitespace removed. Blank lines are
        skipped so they do not break the later int() conversion in get_seeds().
    '''
    with open(seed_file, encoding='utf-8') as ifile:
        l_line = ifile.read().splitlines()

    l_seed = [line.strip() for line in l_line if line.strip()]

    return l_seed
#--------------------------------
def get_seeds():
    '''
    Collect the toy seeds from every `*.sd` file in the current directory.

    Returns
    -------
    list of int
        Seeds in the order they appear in the files. Files are read in sorted
        name order, so the result does not depend on the directory listing order.

    Raises
    ------
    RuntimeError
        If no seed is found.
    '''
    # glob returns files in arbitrary order; sort for reproducible seed ordering
    l_seed_file = sorted(glob.glob('*.sd'))
    l_seed      = []
    for seed_file in l_seed_file:
        l_seed += get_file_seeds(seed_file)

    if not l_seed:
        log.error('No seeds found')
        # A bare `raise` here (no active exception) only yields a confusing
        # "No active exception to reraise" RuntimeError; raise the same type explicitly.
        raise RuntimeError('No seeds found')

    log.debug(f'Using seeds: {l_seed}')

    l_seed_int = [int(rseed) for rseed in l_seed]

    return l_seed_int
#--------------------------------
# Script entry point: run the toy fits only when executed directly, not when imported.
if __name__ == '__main__':
    main()

