#!/usr/bin/env python3

import os
import re
import ROOT
import numpy
import pprint
import argparse
import pandas             as pnd
import utils_noroot       as utnr
import matplotlib.pyplot  as plt
import version_management  as vm
import selection.utilities as slut

from scipy.stats import gaussian_kde
from log_store   import log_store

from importlib.resources import files

import read_selection as rs
import extset

# Module-level logger, used by every function below for progress and error messages
log=log_store.add_logger('rk_extractor:check_dist')
#------------------------------------
def get_lumi_wgt():
    '''
    Build the per-year weights used when merging datasets.

    Returns
    -------
    dict mapping each year (string) to its entry in the table below divided
    by the sum of all entries, so the weights add up to one.
    '''
    # NOTE(review): presumably integrated luminosities per year; units are not stated here
    lumi_by_year = {
            '2011' : 1,
            '2012' : 2,
            '2015' : 0.3,
            '2016' : 1.6,
            '2017' : 1.7,
            '2018' : 2.1}

    total = sum(lumi_by_year.values())

    return {year : lumi_by_year[year] / total for year in lumi_by_year}
#------------------------------------
class data:
    '''
    Container for the settings shared by the functions of this script
    '''
    # Output directory for the plots, filled by get_args
    plt_dir   = None

    # Per-year weights, computed once when the class is created
    d_lum_wgt = get_lumi_wgt()

    # Values accepted by --kind
    l_kind    = [
            'bdt_cmb_sig',
            'mass_qsq',
            'bdt_tran',
            'jpsi_leak']

    # Values accepted by --year; replaced by the user's choice in get_args
    l_year    = [
            '2011',
            '2012',
            '2015',
            '2016',
            '2017',
            '2018']

    # Values accepted by --trig; replaced by the user's choice in get_args
    l_trig    = [
            'MTOS',
            'ETOS',
            'GTIS']
#------------------------------------
def get_args():
    '''
    Parse the command line and store the result as attributes of `data`.

    Sets data.kind, data.l_year, data.l_trig and data.plt_dir; the latter is
    whatever utnr.make_dir_path returns for the --pdir argument.

    Note that when -k is omitted, data.kind is the full list data.l_kind
    rather than a single string, because that list is the declared default.
    '''
    cli = argparse.ArgumentParser(description='Used to make plots of different distributions')

    # (flags, keyword arguments) for each command line option, in declaration order
    l_option = [
            (('-k', '--kind'), dict(type =str, help='Kind of plot'      , choices=data.l_kind, default=data.l_kind)),
            (('-y', '--year'), dict(nargs='+', help='Datasets'          , choices=data.l_year, default=data.l_year)),
            (('-t', '--trig'), dict(nargs='+', help='Triggers'          , choices=data.l_trig, default=data.l_trig)),
            (('-p', '--pdir'), dict(type =str, help='Plotting directory', required=True)),
            ]

    for flags, settings in l_option:
        cli.add_argument(*flags, **settings)

    opts = cli.parse_args()

    data.kind    = opts.kind
    data.l_year  = opts.year
    data.l_trig  = opts.trig
    data.plt_dir = utnr.make_dir_path(opts.pdir)
#------------------------------------
def check_dist(kind):
    '''
    Run the check associated to the requested kind of plot.

    Parameters
    ----------
    kind : str, or list/tuple of str
        Name of the check to run. A list or tuple runs every entry in order;
        this is what the command line provides when -k is not given, since
        the default of that option is the full list of kinds.

    Raises
    ------
    ValueError
        If a kind is not one of the known checks.
    '''
    # A sequence of kinds is expanded one by one
    if isinstance(kind, (list, tuple)):
        for single_kind in kind:
            check_dist(single_kind)

        return

    d_checker = {
            'bdt_cmb_sig' : check_bdt_cmb_sig,
            'mass_qsq'    : check_mass_qsq,
            'bdt_tran'    : check_bdt_tran,
            'jpsi_leak'   : check_jpsi_leak}

    checker = d_checker.get(kind)
    if checker is None:
        log.error(f'Kind not recognized: {kind}')
        # A bare raise has no active exception here, so raise an explicit one
        raise ValueError(f'Kind not recognized: {kind}')

    checker()
#------------------------------------
def get_label(arr_mass, trig):
    '''
    Build the legend entry for one trigger category.

    Parameters
    ----------
    arr_mass : sized container
        Entries plotted for this trigger; only its length is used.
    trig : str
        Trigger name, e.g. one of data.l_trig.

    Returns
    -------
    str of the form '<trigger label>, <number of entries>'. Triggers without
    a dedicated label are shown under their own name, instead of being
    reported as 'gTIS!'.
    '''
    d_trig_label = {'ETOS' : 'eTOS', 'GTIS' : 'gTIS!'}

    label = d_trig_label.get(trig, trig)
    size  = len(arr_mass)

    return f'{label}, {size}'
#------------------------------------
def check_jpsi_leak():
    '''
    Plot the B mass of the entries passing the high-q2 selection, one
    normalized histogram per trigger, and save it as leak.png in data.plt_dir.

    For each trigger the years in data.l_year are merged, with every entry
    weighted by the per-year weight in data.d_lum_wgt. Input files are
    located through the CASDIR environment variable.

    Raises
    ------
    KeyError
        If CASDIR is not set in the environment.
    '''
    cas_dir  = os.environ['CASDIR']

    # Histogram binning; the bin width quoted in the y axis label is derived from it
    nbins    = 50
    min_mass = 5500
    max_mass = 7000
    bin_wid  = (max_mass - min_mass) / nbins

    for trig in data.l_trig:
        l_arr_mass = []
        l_arr_wgt  = []
        for year in data.l_year:
            log.info(f'Retrieving jpsi leaking MC for: {trig}/{year}')
            qsq_cut = rs.get('q2' , trig, 'high', year)
            pid_cut = rs.get('pid', trig, 'high', year)
            bdt_cut = rs.get('bdt', trig, 'high', year)

            root_wc = f'{cas_dir}/tools/apply_selection/q2_smear/ctrl/v10.21p2/{year}_{trig}/*.root'
            rdf = ROOT.RDataFrame(trig, root_wc)
            rdf = rdf.Filter(qsq_cut, 'qsq_cut')
            rdf = rdf.Filter(pid_cut, 'pid_cut')
            rdf = rdf.Filter(bdt_cut, 'bdt_cut')

            # Book the cut flow before reading the column, and print it afterwards:
            # the event loop started by AsNumpy should fill the report as well, so the
            # files are not processed a second time just for the report (per ROOT RDataFrame docs)
            rep      = rdf.Report()
            arr_mass = rdf.AsNumpy(['B_M'])['B_M']
            rep.Print()

            wgt      = data.d_lum_wgt[year]
            log.info(f'Using weight {wgt:.3f} for {year}')
            # Explicit float dtype, so the weight is never cast to the dtype of the mass column
            arr_wgt  = numpy.full(arr_mass.shape, wgt, dtype=float)

            l_arr_mass.append(arr_mass)
            l_arr_wgt.append(arr_wgt)

        arr_mass = numpy.concatenate(l_arr_mass)
        arr_wgt  = numpy.concatenate(l_arr_wgt)
        lab      = get_label(arr_mass, trig)
        plt.hist(arr_mass, bins=nbins, range=(min_mass, max_mass), label=lab, histtype='step', weights=arr_wgt, density=True)

    plt.legend()
    plt.xlabel('M(B)[MeV]')
    plt.ylabel(f'Normalized/{bin_wid:.0f}MeV')
    log.info(f'Saving to: {data.plt_dir}/leak.png')
    plt.grid()
    plt.tight_layout()
    plt.savefig(f'{data.plt_dir}/leak.png')
    plt.close('all')
#------------------------------------
def check_bdt_tran():
    '''
    Plot slut.transform_bdt over [-1, 1] in steps of 0.01, overlay the BDT
    bin boundaries as horizontal lines and save the figure as bdt_tran.png
    in data.plt_dir.
    '''
    l_original    = [0.01 * step for step in range(-100, 101)]
    l_transformed = list(map(slut.transform_bdt, l_original))

    plt.plot(l_original, l_transformed)
    plt.xlabel('Original')
    plt.ylabel('Transformed')

    # Boundaries are drawn on the y axis, i.e. on the transformed score
    overlay_bdt_bounds(vertical=False, recalculated=False)

    plot_path = f'{data.plt_dir}/bdt_tran.png'
    plt.savefig(plot_path)
    plt.close('all')
#------------------------------------
def check_mass_qsq():
    '''
    For every year and trigger, make a scatter plot of B_M against
    Jpsi_M squared, colored by a Gaussian kernel density estimate, with the
    high-q2 selection limits overlaid.

    One file is written per dataset, named mass_qsq_<year>_<trigger>.png in
    data.plt_dir, so that plots of different datasets do not overwrite each
    other. Input files are located through the CASDIR environment variable.

    Raises
    ------
    KeyError
        If CASDIR is not set in the environment.
    '''
    cas_dir = os.environ['CASDIR']
    dat_dir = f'{cas_dir}/tools/apply_selection/blind_fits/data'
    for year in data.l_year:
        for trig in data.l_trig:
            dat_wc    = f'{dat_dir}/v10.21p2/{year}_{trig}/*.root'
            log.info(f'Picking up data from: {dat_wc}')
            rdf       = ROOT.RDataFrame(trig, dat_wc)
            d_dat     = rdf.AsNumpy(['Jpsi_M', 'B_M'])
            df        = pnd.DataFrame(d_dat)
            df['qsq'] = df.Jpsi_M ** 2

            # Density of the points, evaluated at the points themselves, used as color
            dat = numpy.vstack([df.B_M, df.qsq])
            z   = gaussian_kde(dat)(dat)

            # The dataset is part of the file name; a single shared name kept only the last plot
            plot_path = f'{data.plt_dir}/mass_qsq_{year}_{trig}.png'

            plt.scatter(df.B_M, df.qsq, c=z, s=1)
            log.info(f'Saving to: {plot_path}')
            add_mas_qsq_lines(trig, year)
            plt.savefig(plot_path)
            plt.close('all')
#------------------------------------
def get_lims(rgx, cut):
    '''
    Extract a lower and an upper limit from a cut string.

    Parameters
    ----------
    rgx : str
        Regular expression with exactly two capturing groups, the minimum
        and the maximum, in that order. It is matched from the start of cut.
    cut : str
        Cut expression to parse.

    Returns
    -------
    Tuple (minimum, maximum) of floats.

    Raises
    ------
    ValueError
        If the expression does not match the cut, or if a captured group
        cannot be converted to float.
    '''
    mtch = re.match(rgx, cut)
    if not mtch:
        log.error(f'Cannot match {rgx} in {cut}')
        # A bare raise has no active exception here, so raise an explicit one
        raise ValueError(f'Cannot match {rgx} in {cut}')

    mins, maxs = mtch.groups()

    return float(mins), float(maxs)
#------------------------------------
def add_mas_qsq_lines(trig, year):
    '''
    Draw on the current figure the limits of the high-q2 bin selection:
    horizontal red lines at the q2 limits and vertical red lines at the
    mass limits.

    Parameters
    ----------
    trig : str
        Trigger passed to rs.get when retrieving the cuts.
    year : str
        Year passed to rs.get when retrieving the cuts.
    '''
    cut_q2 = rs.get('q2'  , trig, q2bin='high', year=year)
    cut_ms = rs.get('mass', trig, q2bin='high', year=year)

    # Both cuts are parsed before anything is drawn
    t_lim_q2 = get_lims(r'\(Jpsi_M \* Jpsi_M > ([0-9\.]+)\) && \(Jpsi_M \* Jpsi_M < ([0-9\.]+)\)', cut_q2)
    t_lim_ms = get_lims(r'\(B_M\s+> ([0-9]+)\) && \(B_M\s+< ([0-9]+)\)', cut_ms)

    for lim_q2 in t_lim_q2:
        plt.axhline(y=lim_q2, color='r')

    for lim_ms in t_lim_ms:
        plt.axvline(x=lim_ms, color='r')
#------------------------------------
def get_bounds(recalculated=False):
    '''
    Return the list of [minimum, maximum] boundaries of the BDT bins.

    Parameters
    ----------
    recalculated : bool
        If False, the 'BDT_cmb' entry of the settings of bins 1 to 5 is
        returned as is, in that order.
        If True, the boundaries of bin 5 are passed through
        slut.inverse_transform_bdt and four more intervals are appended, each
        ending where the previous one starts and with a width of 1, 2, 3 and
        4 times the width of the first one.
    '''
    def _cmb_range(index):
        # First element of the settings holds the boundaries, under 'BDT_cmb'
        settings, _, _ = extset.get_bdt_bin_settings(index)

        return settings['BDT_cmb']

    if not recalculated:
        return [_cmb_range(index) for index in [1, 2, 3, 4, 5]]

    [raw_lower, raw_upper] = _cmb_range(5)
    lower = slut.inverse_transform_bdt(raw_lower)
    upper = slut.inverse_transform_bdt(raw_upper)
    width = upper - lower

    l_interval = [[lower, upper]]
    for factor in [1, 2, 3, 4]:
        # The new interval ends at the start of the previous one
        upper = lower
        lower = upper - factor * width

        l_interval.append([lower, upper])

    return l_interval
#------------------------------------
def overlay_bdt_bounds(vertical=True, recalculated=False):
    '''
    Draw thin red dash-dotted lines on the current figure at both edges of
    every interval returned by get_bounds.

    Parameters
    ----------
    vertical : bool
        If true the lines are vertical, otherwise horizontal.
    recalculated : bool
        Forwarded to get_bounds.
    '''
    line_style = dict(linestyle='-.', linewidth=0.3, color='r')

    for [lower, upper] in get_bounds(recalculated=recalculated):
        for edge in (lower, upper):
            if vertical:
                plt.axvline(x=edge, **line_style)
            else:
                plt.axhline(y=edge, **line_style)
#------------------------------------
def check_bdt_cmb_sig():
    '''
    Plot the weighted distributions of BDT_cmb and of its value after
    slut.inverse_transform_bdt, for every trigger and year, on a single
    figure with logarithmic y axis. The figure is saved as bdt_cmb_sig.png
    in data.plt_dir.

    The inputs are the <trigger>_<year>.json files in the latest version
    directory of the sig_wgt data of the extractor_data package.
    '''
    dir_path = files('extractor_data').joinpath('sig_wgt')
    ver_path = vm.get_last_version(dir_path=dir_path, version_only=False)
    for trig in data.l_trig:
        for year in data.l_year:
            json_path = f'{ver_path}/{trig}_{year}.json'
            df = pnd.read_json(json_path)
            df['BDT_org'] = df.BDT_cmb.apply(slut.inverse_transform_bdt)

            # Each histogram carries its own label, so that every curve can be told apart
            # when several triggers and years share the figure
            plt.hist(df.BDT_cmb, bins=50, weights=df.wgt, label=f'Transformed; {trig}; {year}', histtype='step')
            plt.hist(df.BDT_org, bins=50, weights=df.wgt, label=f'Original; {trig}; {year}', histtype='step')

    # The boundaries do not depend on the dataset: draw them once, not once per histogram pair
    overlay_bdt_bounds()

    # Uses the labels given to the histograms; the boundary lines have none and are skipped
    plt.legend()
    plt.yscale('log')
    plt.savefig(f'{data.plt_dir}/bdt_cmb_sig.png')
    plt.close('all')
#------------------------------------
def main():
    '''
    Script entry point: read the command line arguments into `data` and run
    the requested check.
    '''
    get_args()

    requested = data.kind
    check_dist(requested)
#------------------------------------
# Run the checks only when this file is executed as a script, not when it is imported
if __name__ == '__main__':
    main()

