#!/usr/bin/env python3

import os
import zfit
import glob
import numpy
import pandas              as pnd
import logzero
import tarfile
import rk.utilities        as rkut
import utils_noroot        as utnr

from rkex_model import model
from np_reader  import np_reader as np_rdr
from extractor  import extractor as ext
from logzero    import logger    as log
from zutils     import utils     as zut

#--------------------------------
class data:
    """Module-wide run configuration shared by the functions of this script."""

    # Directory that receives the per-seed outputs and the final tarballs.
    out_dir: str = 'results'
    # Seeds to fit; populated by initialize() from the *.sd files.
    l_seed: 'list[int] | None' = None
#--------------------------------
def fit(rseed):
    """Build the model and a dataset for one seed, fit it and summarise the result.

    Parameters
    ----------
    rseed : int
        Seed forwarded to ``model.get_data`` when the dataset is built; it is
        also logged with ``{rseed:04}``, so an integer is expected.

    Returns
    -------
    tuple
        ``(result, d_inf)``: ``result`` is the fit result from
        ``extractor.get_fit_result`` after ``hesse()`` and ``freeze()``;
        ``d_inf`` is the dictionary produced by ``result_to_dict`` (built
        after ``hesse()`` and before ``freeze()``).
    """
    log.info(f'Seed: {rseed:04}')

    # Inputs (covariances, efficiencies, r_jpsi, yields) for fixed input versions,
    # cached on disk under 'v65_v63_v24'.
    rdr          = np_rdr(sys='v65', sta='v63', yld='v24')
    rdr.cache    = True
    rdr.cache_dir= 'v65_v63_v24'
    cv_sys       = rdr.get_cov(kind='sys')
    cv_sta       = rdr.get_cov(kind='sta')
    d_eff        = rdr.get_eff()
    d_rjpsi      = rdr.get_rjpsi()
    d_byld       = rdr.get_byields()
    # Average the yields, leaving out the 'TIS' entries.
    d_nent       = rkut.average_byields(d_byld, l_exclude=['TIS'])
    # NOTE(review): presumably converts resonant (J/psi) yields into expected rare-mode yields.
    d_rare_yld   = rkut.reso_to_rare(d_nent, kind='jpsi')

    mod          = model(preffix='real', d_eff=d_eff)
    d_mod        = mod.get_model()
    # Dataset depends on rseed, which is what makes each iteration different.
    d_dat        = mod.get_data(d_nent=d_rare_yld, rseed=rseed)

    obj          = ext()
    obj.rjpsi    = d_rjpsi
    obj.eff      = d_eff
    # Total covariance = systematic + statistical.
    obj.cov      = cv_sys + cv_sta
    obj.data     = d_dat
    obj.model    = d_mod 
    result       = obj.get_fit_result()

    log.info(f'Calculating errors')
    # hesse() must run before result_to_dict, which reads the 'hesse' errors.
    result.hesse()
    d_inf = result_to_dict(result) 
    # Freeze last: per zfit docs this makes the result picklable (main pickles it).
    result.freeze()

    return result, d_inf
#--------------------------------
def initialize():
    """Prepare a run: INFO-level logging, seed list in ``data.l_seed``, output directory."""
    log.setLevel(logzero.INFO)

    seeds = get_seeds()
    data.l_seed = seeds

    # exist_ok so that re-running in the same directory does not fail.
    os.makedirs(data.out_dir, exist_ok=True)
#--------------------------------
def cleanup_env():
    """Empty zfit's registry of existing parameters.

    Called after every fit in ``main``; presumably this lets the next call to
    ``fit`` recreate parameters with the same names.
    """
    registry = zfit.Parameter._existing_params

    # Snapshot the keys first: deleting from a mapping while iterating it raises.
    for name in list(registry):
        del registry[name]
#--------------------------------
def trim_val(val):
    """Convert a value into a compact, JSON-friendly representation.

    - ``str`` values are returned unchanged.
    - Booleans (Python or NumPy) are returned as a Python ``bool``; formatting
      them as floats would turn ``True`` into ``'1.000'``.
    - Tuples are processed element-wise, e.g. ``(value, error)`` pairs.
    - Anything else is formatted as a positional float string with 4
      significant digits, keeping trailing zeros.
    """
    if isinstance(val, str):
        return val

    # bool must be checked before the numeric path; numpy.bool_ is normalised
    # to a plain bool because it is not JSON-serialisable.
    if isinstance(val, (bool, numpy.bool_)):
        return bool(val)

    if isinstance(val, tuple):
        return tuple(trim_val(elm) for elm in val)

    return numpy.format_float_positional(val, precision=4, unique=False, fractional=False, trim='k')
#--------------------------------
def result_to_dict(res):
    """Summarise a fit result as a JSON-friendly dictionary.

    Parameters
    ----------
    res : fit result
        Object exposing ``params``, ``converged``, ``status``, ``valid``,
        ``fmin``, ``edm`` and ``covariance()`` (a zfit ``FitResult`` here).
        ``hesse()`` should already have been run so that errors are stored.

    Returns
    -------
    dict
        ``{param_name: (value, error), ...}`` plus the keys ``converged``,
        ``status``, ``valid``, ``fmin``, ``edf`` and ``cov`` (nested lists).
        Every entry is passed through ``trim_val``. ``error`` is the sentinel
        ``-999`` (trimmed) when no HESSE error is stored for a parameter.
    """
    d_res = {}
    for par, d_val in res.params.items():
        val = d_val['value']
        try:
            err = d_val['hesse']['error']
        except (KeyError, TypeError):
            # No HESSE entry (or an empty one) for this parameter: use a sentinel.
            err = -999

        d_res[par.name] = val, err

    d_res['converged'] = res.converged
    d_res['status']    = res.status
    d_res['valid']     = res.valid
    d_res['fmin']      = res.fmin
    # NOTE: the key reads 'edf' but holds res.edm; kept unchanged so existing
    # consumers of the JSON files keep working.
    d_res['edf']       = res.edm

    d_res_trim = {key : trim_val(val) for key, val in d_res.items()}

    # Covariance matrix as nested lists, each entry trimmed like the values above.
    l_l_cov = res.covariance().tolist()
    d_res_trim['cov'] = [[trim_val(val) for val in l_cov] for l_cov in l_l_cov]

    return d_res_trim
#--------------------------------
def main():
    """Fit every seed, save the per-seed outputs and archive them.

    For each seed in ``data.l_seed`` the fit result is pickled to
    ``<out_dir>/result_pkl/result_NNNN.pkl`` and its summary dictionary is
    written to ``<out_dir>/result_jsn/result_NNNN.json``. Both directories are
    then packed into ``<out_dir>/result_pkl.tar.gz`` and
    ``<out_dir>/result_jsn.tar.gz``.
    """
    initialize()

    pkl_dir = f'{data.out_dir}/result_pkl'
    jsn_dir = f'{data.out_dir}/result_jsn'
    # Create the sub-directories up front, in case the dump helpers do not, and
    # so that the tarballs below do not fail on missing directories when no
    # seed was found.
    os.makedirs(pkl_dir, exist_ok=True)
    os.makedirs(jsn_dir, exist_ok=True)

    for rseed in data.l_seed:
        res, d_inf = fit(rseed)
        utnr.dump_pickle(res, f'{pkl_dir}/result_{rseed:04}.pkl')
        utnr.dump_json(d_inf, f'{jsn_dir}/result_{rseed:04}.json')

        # Clear zfit's parameter registry before the next fit recreates the model.
        cleanup_env()

    for dir_name in ['result_pkl', 'result_jsn']:
        with tarfile.open(f'{data.out_dir}/{dir_name}.tar.gz', 'w:gz') as tar:
            tar.add(f'{data.out_dir}/{dir_name}', arcname=dir_name)
#--------------------------------
def get_file_seeds(seed_file):
    """Read the seeds listed in a text file, one per line.

    Parameters
    ----------
    seed_file : str
        Path to the seed file.

    Returns
    -------
    list of str
        The file's lines with surrounding whitespace removed. Blank lines are
        skipped, because an empty string would make the ``int`` conversion in
        ``get_seeds`` fail.
    """
    with open(seed_file) as ifile:
        l_line = ifile.read().splitlines()

    return [line.strip() for line in l_line if line.strip()]
#--------------------------------
def get_seeds():
    """Collect the integer seeds from every ``*.sd`` file in the working directory.

    ``glob`` returns files in an arbitrary, filesystem-dependent order. The
    files are therefore read in sorted name order, so the seed list (and the
    processing order in ``main``) is reproducible.

    Returns
    -------
    list of int
        The seeds of all files, concatenated in file-name order.
    """
    l_seed = []
    for seed_file in sorted(glob.glob('*.sd')):
        l_seed += get_file_seeds(seed_file)

    log.debug(f'Using seeds: {l_seed}')

    return [int(rseed) for rseed in l_seed]
#--------------------------------
# Run the seed loop only when executed as a script, not when imported.
if __name__ == '__main__':
    main()

