#!python

import os
import array

import ROOT
import logzero

from logzero import logger as log

from importlib.resources import files

#-------------------------------
class H2V:
    '''
    Builds a ROOT TH2Poly row by row, where every row is a copy of the
    binning and contents of a 1D histogram placed between two y values.
    '''
    def __init__(self, name):
        '''
        name: Name given to the TH2Poly, its title is left empty
        '''
        poly = ROOT.TH2Poly()
        poly.SetName(name)
        poly.SetTitle('')

        self._h_poly = poly
    #---------------------------
    def add_slice(self, hist, min_y, max_y):
        '''
        Add one row of rectangular bins covering [min_y, max_y] in y.

        hist : 1D histogram, each of its bins becomes a polygon bin with the
               same x edges and the same content
        min_y: Lower y edge of the row
        max_y: Upper y edge of the row
        '''
        poly = self._h_poly
        for i_bin in range(1, hist.GetNbinsX() + 1):
            # Upper edge of bin i is the lower edge of bin i + 1
            x_lo, x_hi = hist.GetBinLowEdge(i_bin), hist.GetBinLowEdge(i_bin + 1)
            content    = hist.GetBinContent(i_bin)

            i_new = poly.AddBin(x_lo, min_y, x_hi, max_y)
            poly.SetBinContent(i_new, content)
    #---------------------------
    def get_poly(self):
        '''
        Return the TH2Poly built so far
        '''
        return self._h_poly
#-------------------------------
class data:
    '''
    Namespace with settings shared between the functions of this module.
    '''
    # Directory where makeMap writes its ROOT files; main assigns it once per version
    outdir = None
#-------------------------------
def check_array(org_lst):
    '''
    Check that a sequence of bin edges is sorted in increasing order.

    Parameters
    ----------
    org_lst : Iterable of numbers, the bin edges (lists, tuples and ranges are accepted)

    Raises
    ------
    ValueError: If the edges are not sorted; the message shows the edges
                as given and as they would look once sorted
    '''
    l_edge = list(org_lst)
    l_sort = sorted(l_edge)

    if l_edge != l_sort:
        # A bare `raise` here would fail with "No active exception to reraise"
        raise ValueError(f'Bin edges are not sorted: {l_edge}, expected: {l_sort}')
#-------------------------------
def getHistogram(xname, yname, l_l_y_x, l_y):
    '''
    Build a TH2Poly named "h_poly" made of horizontal rows of bins.

    Parameters
    ----------
    xname  : Title of the x axis
    yname  : Title of the y axis
    l_l_y_x: List with one list of x bin edges per y bin
    l_y    : Bin edges in y, row i spans [l_y[i], l_y[i + 1]]

    Returns
    -------
    The TH2Poly, with axis titles set
    '''
    builder = H2V("h_poly")

    for i_row, (y_lo, y_hi) in enumerate(zip(l_y[:-1], l_y[1:])):
        l_edge = l_l_y_x[i_row]
        check_array(l_edge)

        # Variable-width 1D histogram, only used as a carrier of the x binning of this row
        h_row = ROOT.TH1F(f"h_{i_row}", "", len(l_edge) - 1, array.array('f', l_edge))
        builder.add_slice(h_row, y_lo, y_hi)

    h_poly = builder.get_poly()
    h_poly.GetXaxis().SetTitle(xname)
    h_poly.GetYaxis().SetTitle(yname)

    return h_poly
#-------------------------------
def fillHistogram(h_poly):
    '''
    Set the content of every bin of a TH2Poly to that bin's own number,
    so that the map can be used to look up bin indices.

    Parameters
    ----------
    h_poly: TH2Poly, modified in place

    Returns
    -------
    The same h_poly object
    '''
    for poly_bin in h_poly.GetBins():
        poly_bin.SetContent(poly_bin.GetBinNumber())

    return h_poly
#-------------------------------
def _invalid_version(CHANNEL, TRIGGER, YEAR, version):
    '''
    Log and return the error for a year/version not supported by a channel/trigger pair.
    '''
    msg = f'Invalid year/version for {CHANNEL}/{TRIGGER}: {YEAR}/{version}'
    log.error(msg)

    return ValueError(msg)
#-------------------------------
def _get_binning(CHANNEL, TRIGGER, YEAR, version):
    '''
    Pick the output file prefix, axis titles and binning of one map.

    Parameters
    ----------
    CHANNEL : Channel name: muon, electron, electron_fac, hadron, gem, gmh, gbn (L0)
              or muon, electron, hadron, gtis (HLT)
    TRIGGER : "L0" or "HLT"
    YEAR    : Data-taking year, some binnings depend on it
    version : Version of the binning, e.g. "v0", "v1"

    Returns
    -------
    Tuple (prefix, VAR_X, VAR_Y, l_y, l_l_y_x) with:
    prefix  : Stem of the output file name, e.g. "L0Muon"
    VAR_X   : Title of the x axis
    VAR_Y   : Title of the y axis
    l_y     : Bin edges in y
    l_l_y_x : One list of x bin edges per y bin

    Raises
    ------
    ValueError: If the channel/trigger pair or the year/version is not supported
    '''
    l_yr_l  = [2011, 2015]
    l_yr_h  = [2012, 2016, 2017, 2018]
    l_v2_v5 = ['v2', 'v3', 'v4', 'v5']
    # Shifts the y edges of the integer-valued "Region" variables below each integer,
    # so that every integer value falls inside a bin instead of on an edge
    epsilon = 1e-5

    #----------------------------------------------------------------
    if CHANNEL == "muon" and TRIGGER == "L0":
        if   version == 'v0':
            l_x_1 = list(range(0, 20000 + 100, 100))
        elif YEAR in l_yr_l and version == 'v1':
            l_x_1 = [0, 1585, 1884, 2077, 2226, 2353, 2464, 2565, 2661, 2756, 2852, 2949, 3048, 3150, 3256, 3366, 3478, 3592, 3711, 3833, 3959, 4092, 4230, 4380, 4541, 4717, 4912, 5130, 5378, 5669, 6004, 6423, 6950, 7640, 8613, 10365, 20000]
            l_x_2 = [0, 2599, 3869, 20000]
        elif YEAR in l_yr_l and version in l_v2_v5:
            l_x_1 = [0, 1285, 1584, 1777, 1826, 2053, 2164, 2365, 2461, 2556, 2652, 2749, 2848, 2950, 3056, 3166, 3278, 3392, 3511, 3733, 3959, 4192, 4330, 4480, 4601, 4817, 5012, 5230, 5478, 5669, 6104, 6523, 7050, 7740, 8913, 12365, 20000]
            l_x_2 = [0, 3099, 4069, 20000]
        elif YEAR in l_yr_h and version == 'v1':
            l_x_1 = [0, 1384, 1571, 1695, 1794, 1882, 1963, 2041, 2118, 2194, 2270, 2347, 2425, 2505, 2586, 2672, 2760, 2853, 2954, 3063, 3181, 3311, 3455, 3610, 3783, 3978, 4191, 4433, 4704, 5022, 5407, 5885, 6518, 7422, 9039, 20000]
            l_x_2 = [0, 1882, 2518, 3562, 20000]
        elif YEAR in l_yr_h and version in l_v2_v5:
            l_x_1 = [0, 1584, 1871, 1995, 2094, 2182, 2263, 2341, 2518, 2794, 2870, 3047, 3225, 3405, 3686, 3872, 4060, 4253, 4554, 4663, 4881, 4911, 5255, 5410, 5683, 5878, 6191, 6333, 6804, 7022, 7307, 7685, 8018, 9022, 9939, 20000]
            l_x_2 = [0, 1082, 2018, 3062, 5000, 20000]
        else:
            raise _invalid_version(CHANNEL, TRIGGER, YEAR, version)

        if version == 'v0':
            l_y     = [1.5, 5]
            l_l_y_x = [l_x_1]
        else:
            l_y     = [1.5, 4, 5]
            l_l_y_x = [l_x_1, l_x_2]

        return 'L0Muon', "#mu(p_{T})[MeV]", "#mu(#eta)", l_y, l_l_y_x

    if CHANNEL == "electron_fac" and TRIGGER == "L0":
        if version == 'v0':
            l_x = list(range(0, 10000 + 500, 500))
            l_y = list(range(0, 10000 + 500, 500))
        else:
            l_x = [0, 1000, 2200, 2400, 2600, 2800, 3000, 3200, 3400, 3800, 4000, 10000]
            l_y = [0, 750, 1200, 1500, 1800, 2300, 3000, 3800, 10000]

        # Same x binning in every y bin
        l_l_y_x = [l_x] * (len(l_y) - 1)

        return 'L0ElectronFAC', "Max(E_{T}(e^{+}), E_{T}(e^{-}))", "r(e^{+},e^{-})", l_y, l_l_y_x

    if CHANNEL == "electron" and TRIGGER == "L0":
        l_y = [-1 - epsilon, 0 - epsilon, 1 - epsilon, 2 - epsilon, 3 - epsilon]
        if   version == 'v0':
            l_x_1 = list(range(0, 10000 + 50, 50))
            l_x_2 = list(l_x_1)
            l_x_3 = list(l_x_1)
        elif YEAR in l_yr_l and version == 'v1':
            l_x_1 = [0, 2000, 2375, 2750, 3125, 3500, 3830, 4102, 4423, 4820, 5333, 6062, 7231, 10000]
            l_x_2 = [0, 2000, 2375, 2750, 3125, 3500, 3856, 4170, 4555, 5049, 5740, 6906, 10000]
            l_x_3 = [0, 2000, 2375, 2750, 3125, 3500, 3920, 4282, 4745, 5408, 6508, 10000]
        elif YEAR in l_yr_l and version == 'v2':
            l_x_1 = [0, 2000, 2300, 2600, 2900, 3300, 3500, 3830, 4102, 4423, 4820, 5333, 7231, 10000]
            l_x_2 = [0, 2000, 2300, 2600, 2900, 3300, 3500, 3800, 4170, 5049, 5740, 6906, 10000]
            l_x_3 = [0, 2000, 2300, 2600, 2900, 3300, 3500, 3800, 4170, 5049, 5740, 6906, 10000]
        elif version == 'v1':
            l_x_1 = [0, 2000, 2300, 2600, 2900, 3200, 3500, 3676, 4068, 4577, 5304, 6518, 10000]
            l_x_2 = [0, 2000, 2300, 2600, 2900, 3200, 3500, 3853, 4356, 5057, 6268, 10000]
            l_x_3 = [0, 2000, 2300, 2600, 2900, 3200, 3500, 4110, 4764, 5864, 10000]
        elif version in l_v2_v5:
            l_x_1 = [0, 2000, 2200, 2400, 2800, 3000, 3276, 3500, 4068, 4377, 5004, 6018, 8000, 10000]
            l_x_2 = [0, 2000, 2200, 2400, 2800, 3000, 3253, 3500, 4056, 4557, 5268, 6000, 10000]
            l_x_3 = [0, 2000, 2200, 2400, 2800, 3000, 3210, 3500, 4064, 5064, 7000, 10000]
        else:
            raise _invalid_version(CHANNEL, TRIGGER, YEAR, version)

        # The first two y bins share the same x binning
        l_l_y_x = [l_x_1, l_x_1, l_x_2, l_x_3]

        return 'L0Electron', "e_{probe}^{L0}(E_{T}^{ECAL})", "e_{probe}^{L0}(Region)", l_y, l_l_y_x

    if CHANNEL == "hadron" and TRIGGER == "L0":
        l_y = [-1 - epsilon, 0 - epsilon, 1 - epsilon, 2 - epsilon]
        if   version == 'v0':
            l_x_1 = list(range(0, 10000 + 50, 50))
            l_x_2 = list(l_x_1)
        elif version in ['v1'] + l_v2_v5:
            l_x_1 = [0, 2000, 3000, 4000, 4250, 4500, 4750, 5000, 5500, 6000, 6500, 7000, 8000, 10000, 15000, 25000]
            l_x_2 = [0, 2000, 3000, 4000, 4250, 4500, 4750, 5000, 5500, 6000, 6500, 7000, 8000, 10000, 15000, 25000]
        else:
            raise _invalid_version(CHANNEL, TRIGGER, YEAR, version)

        # The first two y bins share the same x binning
        l_l_y_x = [l_x_1, l_x_1, l_x_2]

        return 'L0Hadron', "K^{L0}(E_{T}^{HCAL})", "K^{L0}(Region)", l_y, l_l_y_x

    if CHANNEL == "gem" and TRIGGER == "L0":
        if   version == 'v0':
            l_x = list(range(0, 30000 + 100, 100))
        elif YEAR in [2011, 2012, 2015] and version == 'v1':
            l_x = [0, 2000, 4000, 5000, 6000, 6500, 7000, 8000, 9000, 10000, 30000]
        elif YEAR in [2011, 2012, 2015] and version == 'v2':
            l_x = [0, 1500, 3000, 3500, 4500, 5500, 6000, 7500, 9500, 11000, 30000]
        elif version == 'v1':
            l_x = [0, 2252, 2964, 3635, 4328, 5113, 6155, 6881, 7941, 9760, 30000]
        elif version in l_v2_v5:
            l_x = [0, 2052, 2564, 3035, 4028, 5013, 5555, 6081, 7041, 8060, 10000, 30000]
        else:
            raise _invalid_version(CHANNEL, TRIGGER, YEAR, version)

        return 'L0GEM', "max(e^{+}(p_{T}), e^{-}(p_{T}))", "None", [0, 8], [l_x]

    if CHANNEL == "gmh" and TRIGGER == "L0":
        if   version == 'v0':
            l_x = list(range(0, 100000 + 500, 500))
        elif YEAR in l_yr_l and version == 'v1':
            l_x = [0, 3862, 6017, 8417, 12198, 100000]
        elif YEAR in l_yr_l and version == 'v2':
            l_x = [0, 3062, 5017, 6417, 10198, 15000, 100000]
        elif version == 'v1':
            l_x = [0, 3823, 5949, 8289, 11971, 100000]
        elif version in l_v2_v5:
            l_x = [0, 3023, 5049, 6289, 8000, 10071, 13000, 100000]
        else:
            raise _invalid_version(CHANNEL, TRIGGER, YEAR, version)

        return 'L0GMH', "B(p_{T})", "None", [0, 8], [l_x]

    if CHANNEL == "gbn" and TRIGGER == "L0":
        if   version == 'v0':
            l_x = list(range(0, 30000 + 100, 100))
        elif YEAR in [2011, 2012, 2015] and version == 'v1':
            l_x = [0, 2000, 4000, 5000, 6000, 6500, 7000, 8000, 9000, 10000, 30000]
        elif YEAR in [2011, 2012, 2015] and version == 'v2':
            l_x = [0, 1500, 3000, 3500, 4500, 5500, 6000, 7500, 9500, 11000, 30000]
        elif version == 'v1':
            l_x = [0, 2252, 2964, 3635, 4328, 5113, 6155, 6881, 7941, 9760, 30000]
        elif version in ['v2', 'v3', 'v4']:
            l_x = [0, 2052, 2564, 3035, 4028, 5013, 5555, 6081, 7041, 8060, 10000, 30000]
        elif version == 'v5':
            l_x = [0, 1052, 2364, 2835, 3528, 4513, 5355, 5781, 6541, 7560, 9000, 20000, 30000]
        else:
            raise _invalid_version(CHANNEL, TRIGGER, YEAR, version)

        return 'L0GBN', "max(e^{+}(p_{T}), e^{-}(p_{T}))", "None", [0, 8], [l_x]
    #----------------------------------------------------------------
    # HLT channel -> (file prefix, binning up to v3, binning for v4 and v5).
    # For v4/v5 the HLT maps are binned in nTracks (set_title renames the x variable accordingly)
    d_hlt = {
        'muon'    : ('HLTMuon'    , [0, 7000, 10000, 14000, 40000], [0, 88, 115, 135, 150, 175, 200, 230, 380]),
        'electron': ('HLTElectron', [0, 7000, 10000, 14000, 40000], [0, 95, 125, 150, 175, 210, 380]),
        'hadron'  : ('HLTHadron'  , [0, 8000, 12000, 16000, 40000], [0, 95, 125, 150, 175, 210, 380]),
        'gtis'    : ('HLTGTIS'    , [0, 4000, 7000, 11000, 40000] , [0, 95, 125, 150, 175, 210, 380]),
    }

    if TRIGGER == "HLT" and CHANNEL in d_hlt:
        prefix, l_x_old, l_x_new = d_hlt[CHANNEL]
        l_x = l_x_new if version in ['v4', 'v5'] else l_x_old

        return prefix, "B(p_{T})[MeV]", "B(#eta)", [0, 8], [l_x]
    #----------------------------------------------------------------
    log.error("Wrong settings:")
    log.error(f"Channel: {CHANNEL}")
    log.error(f"Trigger: {TRIGGER}")
    raise ValueError(f'Unsupported channel/trigger: {CHANNEL}/{TRIGGER}')
#-------------------------------
def makeMap(CHANNEL, TRIGGER, YEAR, version):
    '''
    Build the TH2Poly map for one channel/trigger/year/version and save it.

    Every bin of the map holds its own bin number. The map is written to
    "{data.outdir}/{prefix}_NA_{YEAR}.root", where the prefix depends on the
    channel and trigger, e.g. L0Muon or HLTElectron.

    Parameters
    ----------
    CHANNEL : Channel name, e.g. "muon", "electron", "hadron"
    TRIGGER : "L0" or "HLT"
    YEAR    : Data-taking year
    version : Version of the binning, e.g. "v0", "v1"

    Raises
    ------
    ValueError: If the settings are not supported, the binning is inconsistent
                or data.outdir was not set
    OSError   : If the output file cannot be opened
    '''
    log.debug('Making map for {}/{}/{}'.format(CHANNEL, TRIGGER, YEAR))

    prefix, VAR_X, VAR_Y, l_y, l_l_y_x = _get_binning(CHANNEL, TRIGGER, YEAR, version)

    nrow_1 = len(l_l_y_x)
    nrow_2 = len(l_y) - 1
    if nrow_1 != nrow_2:
        msg = f'Number of rows ({nrow_1}) is not equal to number of bins in y ({nrow_2}), for {CHANNEL}, {TRIGGER}'
        log.error(msg)
        log.error(f'X binnings: {l_l_y_x}')
        log.error(f'Y binning: {l_y}')
        raise ValueError(msg)

    if data.outdir is None:
        raise ValueError('Output directory not set, assign data.outdir before calling makeMap')

    ofilename = "{}/{}_{}_{}.root".format(data.outdir, prefix, "NA", YEAR)

    log.debug("Making histogram")
    h_poly = getHistogram(VAR_X, VAR_Y, l_l_y_x, l_y)

    log.debug("Filling histogram")
    h_poly = fillHistogram(h_poly)

    # NOTE(review): lexicographic string comparison, only correct for
    # single-digit versions ('v10' > 'v2' is False)
    if version > 'v2':
        h_poly = set_title(h_poly, CHANNEL, TRIGGER, version)

    ofile = ROOT.TFile(ofilename, "recreate")
    if ofile.IsZombie():
        raise OSError(f'Cannot open output file: {ofilename}')

    try:
        h_poly.Write()
    finally:
        ofile.Close()
#-------------------------------
def set_title(hist, chan, trig, vers):
    '''
    Set the histogram title to "<x branch>:<y branch>", naming the variables
    the map is binned in.

    Parameters
    ----------
    hist : Histogram (TH2Poly) to modify, changed in place
    chan : Channel name, e.g. "muon", "electron", "gbn"
    trig : "L0" or "HLT"
    vers : Version of the binning

    Returns
    -------
    The same histogram object

    Raises
    ------
    ValueError: If the channel/trigger pair is not supported
    '''
    if   trig == 'HLT':
        exp_x, exp_y = 'B_PT', 'B_ETA'
    elif trig == 'L0' and chan in ['electron', 'electron_fac']:
        exp_x, exp_y = 'L1_L0Calo_ECAL_realET', 'L1_L0Calo_ECAL_region'
    elif trig == 'L0' and chan == 'hadron':
        exp_x, exp_y = 'H_L0Calo_HCAL_realET', 'H_L0Calo_HCAL_region'
    elif trig == 'L0' and chan == 'muon':
        exp_x, exp_y = 'L2_PT', 'L2_ETA'
    elif trig == 'L0' and chan == 'gmh':
        exp_x, exp_y = 'B_PT', 'B_ETA'
    elif trig == 'L0' and chan in ['gem', 'gbn']:
        exp_x, exp_y = 'max_lep_pt', 'B_ETA'
    else:
        msg = f'Invalid channel and/or trigger: {chan}/{trig}'
        log.error(msg)
        raise ValueError(msg)

    # Only the HLT maps switch to an nTracks binning in v4/v5; the L0 maps keep
    # their original x variable, so their axis title must not be overwritten
    if vers in ['v4', 'v5'] and trig == 'HLT':
        exp_x = 'nTracks'
        hist.GetXaxis().SetTitle('nTracks')

    hist.SetTitle(f'{exp_x}:{exp_y}')

    return hist
#-------------------------------
def main():
    '''
    Write the trigger maps for every version, year and channel/trigger pair.

    The files go to <tools_data>/trigger/<version>; the directories are
    created if they do not exist.
    '''
    logzero.loglevel(logzero.INFO)

    trigger_dir = files('tools_data').joinpath('trigger')
    l_version   = ['v0', 'v1', 'v3', 'v4', 'v5']
    l_year      = [2011, 2012, 2015, 2016, 2017, 2018]
    l_chan_trig = [
            (    "muon"    ,  "L0"),
            ("electron"    ,  "L0"),
            ("electron_fac",  "L0"),
            (  "hadron"    ,  "L0"),
            (     "gem"    ,  "L0"),
            (     "gmh"    ,  "L0"),
            (     "gbn"    ,  "L0"),
            (    "muon"    , "HLT"),
            ("electron"    , "HLT"),
            (  "hadron"    , "HLT"),
            (    "gtis"    , "HLT"),
            ]

    for version in l_version:
        data.outdir = f'{trigger_dir}/{version}'
        os.makedirs(data.outdir, exist_ok=True)

        for year in l_year:
            log.info(f'Running for {year}')
            for chan, trig in l_chan_trig:
                makeMap(chan, trig, year, version)
#-------------------------------
if __name__ == '__main__':
    main()


