#!/usr/bin/env python3

import os
import re
import ROOT
import zfit
import glob
import math
import toml
import numpy
import pprint
import logzero
import tarfile
import argparse
import subprocess
import jacobi              as jac

from importlib.resources import files
from pprint     import pformat
from logzero    import logger    as log
from zutils     import utils     as zut
from log_store  import log_store

#--------------------------------
class data:
    '''
    Namespace holding the settings shared by the functions in this script.
    It is never instantiated, attributes are read and written on the class itself.
    '''
    # Settings with fixed defaults
    bdt_bin  = 5
    out_dir  = 'results'
    rk       = 1

    # Filled by initialize (l_seed) and read_config (the rest)
    l_seed = l_dset = l_rkval = l_fixed = l_model = None

    # log_lvl is filled by get_args, the job kind flags by check_job_kind
    log_lvl = for_syst = for_pull = None
#--------------------------------
def get_ne_args(d_pos, d_pre, ck_name, nsg_mm_name):
    '''
    Collect the values and covariance matrix of ck, nsg_mm and rk, both in that order.

    Parameters
    ----------
    d_pos       : Dictionary mapping parameter names to pairs whose first element is the
                  value (the second element is ignored here). It also holds the keys
                  'par' (list of parameter names) and 'cov' (square matrix, assumed to
                  have its rows and columns in the order of 'par' -- TODO confirm)
    d_pre       : Dictionary with the same pair layout, only used to look up ck_name
                  when it is missing from d_pos
    ck_name     : Name of the ck parameter
    nsg_mm_name : Name of the nsg_mm parameter

    Returns
    -------
    Tuple with:

    - List [ck, nsg_mm, rk]
    - 3x3 numpy array of floats with the covariance in that same order. When ck is taken
      from d_pre, its row and column are filled with zeros.

    Raises
    ------
    KeyError   : If a needed parameter is missing from the dictionaries
    ValueError : If a needed parameter is missing from d_pos['par']
    '''
    nsg_mm, _ = d_pos[nsg_mm_name]
    rk, _     = d_pos['rk']

    ck_in_pos = ck_name in d_pos
    if ck_in_pos:
        ck, _ = d_pos[ck_name]
        l_par = [ck_name, nsg_mm_name, 'rk']
    else:
        ck, _ = d_pre[ck_name]
        l_par = [nsg_mm_name, 'rk']

    # Pick rows and columns explicitly in the order of l_par, dropping the unwanted ones
    # would leave the matrix in the order of d_pos['par'], which does not need to be the
    # order of the returned values
    l_name = list(d_pos['par'])
    l_ind  = [ l_name.index(par) for par in l_par ]

    cov    = numpy.asarray(d_pos['cov'], dtype=float)
    cov    = cov[numpy.ix_(l_ind, l_ind)]

    if not ck_in_pos:
        # ck is not in this covariance, prepend a row and a column of zeros for it
        cov = numpy.pad(cov, (1, 0))

    return [ck, nsg_mm, rk], cov
#--------------------------------
def get_ne(suffix, d_pos, d_pre):
    '''
    Evaluate ck * nsg_mm / rk and its uncertainty through error propagation.

    Parameters
    ----------
    suffix : String appended to 'ck_' to name the ck parameter. The nsg_mm parameter
             uses the same suffix with '_TIS_' replaced by '_TOS_'
    d_pos  : Dictionary passed on to get_ne_args
    d_pre  : Dictionary passed on to get_ne_args

    Returns
    -------
    List [value, error], both python floats. Going by the variable names of the
    original code this is presumably the electron channel signal yield.
    '''
    def _ratio(x):
        # x holds [ck, nsg_mm, rk], the order used by get_ne_args
        return (x[0] * x[1]) / x[2]

    name_ck  = f'ck_{suffix}'
    name_nsg = 'nsg_mm_' + suffix.replace('_TIS_', '_TOS_')

    l_input, mat_cov = get_ne_args(d_pos, d_pre, name_ck, name_nsg)
    central, variance = jac.propagate(_ratio, l_input, mat_cov)
    error = math.sqrt(variance)

    return [float(central), float(error)]
#--------------------------------
def initialize():
    '''
    Prepare the job: set the logging level of log_store, load the seeds, work out
    the kind of job and make sure that the output directory exists.
    '''
    log_store.log_level = data.log_lvl

    # check_job_kind counts the seeds, so they have to be loaded first
    l_seed      = get_seeds()
    data.l_seed = l_seed
    check_job_kind()

    os.makedirs(data.out_dir, exist_ok=True)
#--------------------------------
def check_job_kind():
    '''
    Decide whether this is a systematics or a pulls job and store the answer in
    data.for_syst and data.for_pull.

    The decision uses the number of seeds, fixed variables and models:

    - Systematics : exactly one seed and at least one fixed variable or model
    - Pulls       : at least one seed and neither fixed variables nor models

    Raises
    ------
    RuntimeError : If the configuration matches neither kind of job
    '''
    nseed = len(data.l_seed)
    nfix  = len(data.l_fixed)
    nmod  = len(data.l_model)

    data.for_syst = nseed == 1 and (nfix >  0 or  nmod  > 0)
    data.for_pull = nseed >= 1 and  nfix == 0 and nmod == 0

    if data.for_syst:
        log.info('Running systematics job')
        return

    if data.for_pull:
        log.info('Running pulls job')
        return

    message = f'Misconfigured job, seeds/fixed/models = {nseed}/{nfix}/{nmod}'
    log.error(message)
    log.info('-------------')
    log.info('Valid:')
    log.info('-------------')
    log.info('nSeed=1; nFix > 0 or  nMod > 0')
    log.info('nSeed>=1; nFix = 0 and nMod = 0')

    # An explicit exception is needed, a bare raise outside an except block only
    # produces "No active exception to reraise"
    raise RuntimeError(message)
#--------------------------------
def get_args():
    '''
    Parse the command line, store the logging level in data.log_lvl and load the
    configuration whose version was requested.
    '''
    l_level = [logzero.DEBUG, logzero.INFO, logzero.WARNING]

    parser = argparse.ArgumentParser(description='Used run toy fits on model used to extract RK')
    parser.add_argument('-l', '--level', type=int, choices=l_level, default=logzero.INFO, help='Logging level')
    parser.add_argument('-v', '--vers' , type=str, required=True, help='Version of configuration')
    opts = parser.parse_args()

    data.log_lvl = opts.level
    read_config(opts.vers)
#--------------------------------
def read_config(version):
    '''
    Load config/{version}.toml from the extractor_data package and copy its settings
    into the data class.

    Parameters
    ----------
    version : Name of the TOML file, without extension, inside the config directory

    Fills data.l_dset, data.l_rkval, data.l_fixed and data.l_model. The latter is the
    concatenation of the sig_mod setting and the models built from rpr_mod and cpr_mod.
    '''
    toml_path    = files('extractor_data').joinpath(f'config/{version}.toml')
    cfg          = toml.load(toml_path)

    d_input      = cfg['input']
    data.l_dset  = d_input['datasets']
    data.l_rkval = d_input['rk_val'  ]

    d_syst       = cfg['systematics']
    data.l_fixed = d_syst['fix_var']
    l_signal     = d_syst['sig_mod']
    l_rare_prec  = extract_models(cfg, 'rpr_mod')
    l_char_prec  = extract_models(cfg, 'cpr_mod')
    data.l_model = l_signal + l_rare_prec + l_char_prec

    separator = '-' * 20
    log.debug(separator)
    log.info('Reading configuration')
    log.debug(separator)

    d_report = {
            'Datasets'   : data.l_dset,
            'Fixed vars' : data.l_fixed,
            'Models'     : data.l_model,
            'RK values'  : data.l_rkval}

    for title, content in d_report.items():
        log.debug(f'{title:<20}\n' + pformat(content))

    log.debug(separator)
#--------------------------------
def extract_models(cfg, kind):
    '''
    Expand a bootstrapping setting of the form xxx_TRG:bts_INI_FNL into the list of
    models xxx_TRG:btsN, with N running from INI to FNL, both included.

    Parameters
    ----------
    cfg  : Dictionary with the configuration, the setting is read from cfg['systematics'][kind]
    kind : Key of the setting, e.g. 'rpr_mod'

    Returns
    -------
    List of strings. It is empty when INI and FNL are the same number, which is used
    to turn off bootstrapping for a model, e.g. rpr_ETOS:bts_1_1

    Raises
    ------
    RuntimeError : If the setting does not follow the expected format
    '''
    st_mod = cfg['systematics'][kind]
    # Raw string, otherwise \w and \d are invalid escape sequences for python
    rgx = r'(\w{3})_(MTOS|ETOS|GTIS):bts_(\d+)_(\d+)'
    mtc = re.match(rgx, st_mod)
    if not mtc:
        message = f'Invalid model setting: {st_mod}'
        log.error(message)
        raise RuntimeError(message)

    mod, trg, ini, fnl = mtc.groups()

    # Compare as integers, as strings 01 and 1 would be treated as different seeds
    ini = int(ini)
    fnl = int(fnl)

    #Turn off bootstrapping if for this model
    #the initial and final bootstrapping random seed
    #are equal, e.g: rpr_ETOS:bts_1_1
    if ini == fnl:
        return []

    return [f'{mod}_{trg}:bts{val}' for val in range(ini, fnl + 1)]
#--------------------------------
def print_args():
    '''
    Log, at info level, the settings stored in the data class that steer this job.
    '''
    separator = '-' * 40
    d_setting = {
            'Level'      : data.log_lvl,
            'Datasets'   : data.l_dset,
            'Vars fixed' : data.l_fixed,
            'Models'     : data.l_model}

    log.info(separator)
    log.info(f'Args for {__file__}:')
    log.info(separator)
    for label, value in d_setting.items():
        log.info(f'{label:<20}{value}')
    log.info(separator)
#--------------------------------
def run_command(command, arguments=None):
    '''
    Run an external command and wait for it to finish.

    Parameters
    ----------
    command   : Name of the executable
    arguments : List of strings with the arguments. It is mandatory despite the default,
                anything that is not a list is rejected

    Raises
    ------
    ValueError   : If arguments is not a list
    RuntimeError : If the command finishes with a non-zero exit status
    '''
    if not isinstance(arguments, list):
        message = f'Invalid options argument: {arguments}'
        log.error(message)
        raise ValueError(message)

    log.info('-' * 30)
    log.info('-' * 30)
    log.info(f'{command:<10}{str(arguments):<50}')
    log.info('-' * 30)
    log.info('-' * 30)

    # The command is passed as a list, so no shell is involved
    stat = subprocess.run([command] + arguments)

    if stat.returncode != 0:
        log.error(f'Process returned exit status: {stat.returncode}')
        # An explicit exception is needed, a bare raise outside an except block only
        # produces "No active exception to reraise"
        raise RuntimeError(f'Command {command} returned exit status: {stat.returncode}')
#--------------------------------
def run_pull_fits():
    '''
    For a pulls job, run fit_toys once for every combination of seed and RK value.
    It does nothing for any other kind of job.
    '''
    if data.for_pull:
        for seed in data.l_seed:
            for rk_value in data.l_rkval:
                # Command line arguments have to be strings
                run_command('fit_toys', arguments=['-s', seed, '-r', str(rk_value)])
#--------------------------------
def run_syst_fits():
    '''
    For a systematics job, run fit_toys with the single seed available:

    - Once with no extra options
    - Once per fixed variable, where the fixed variables accumulate, i.e. the n-th call
      passes the first n entries of data.l_fixed to -f
    - Once per model in data.l_model, passed to -m

    It does nothing for any other kind of job.
    '''
    if not data.for_syst:
        return

    [seed]    = data.l_seed
    l_common  = ['-s', seed]

    run_command('fit_toys', arguments=list(l_common))

    l_fixed_so_far = []
    for name in data.l_fixed:
        l_fixed_so_far = l_fixed_so_far + [name]
        run_command('fit_toys', arguments=l_common + ['-f'] + l_fixed_so_far)

    for model in data.l_model:
        run_command('fit_toys', arguments=l_common + ['-m', model])
#--------------------------------
def main():
    '''
    Entry point: read the settings, run the fits and pack the result_pkl and result_jsn
    directories, expected inside the output directory, into gzipped tarballs.
    '''
    get_args()
    print_args()
    initialize()
    run_pull_fits()
    run_syst_fits()

    # Presumably these directories are made by fit_toys, nothing in this script makes them
    for kind in ('pkl', 'jsn'):
        dir_name = f'result_{kind}'
        with tarfile.open(f'{data.out_dir}/{dir_name}.tar.gz', 'w:gz') as tar:
            tar.add(f'{data.out_dir}/{dir_name}', arcname=dir_name)
#--------------------------------
def get_file_seeds(seed_file):
    '''
    Read the seeds stored in a text file, one per line.

    Parameters
    ----------
    seed_file : Path to the text file

    Returns
    -------
    List of strings, one per non-blank line, with surrounding whitespace removed.
    Blank lines are dropped, otherwise they would turn into empty seeds.
    '''
    with open(seed_file) as ifile:
        l_line = ifile.read().splitlines()

    l_seed = [ line.strip() for line in l_line ]

    return [ seed for seed in l_seed if seed ]
#--------------------------------
def get_seeds():
    '''
    Collect the seeds from every *.sd file in the current directory.

    Returns
    -------
    List with the seeds returned by get_file_seeds for each file. Files are taken in
    alphabetical order, given that glob does not guarantee any order.

    Raises
    ------
    RuntimeError : If no seed is found
    '''
    l_seed = []
    for seed_file in sorted(glob.glob('*.sd')):
        l_seed += get_file_seeds(seed_file)

    if not l_seed:
        log.error('No seeds found')
        # An explicit exception is needed, a bare raise outside an except block only
        # produces "No active exception to reraise"
        raise RuntimeError('No seeds found in the *.sd files of the current directory')

    log.debug(f'Using seeds: {l_seed}')

    return l_seed
#--------------------------------
# Run the job only when this file is executed as a script, not when it is imported
if __name__ == '__main__':
    main()

