#!/usr/bin/env python3

from misid_check import misid_check

import os
import glob
import ROOT
import argparse
import read_selection    as rs
import pandas            as pnd
import matplotlib.pyplot as plt

from log_store import log_store

# Module-level logger obtained from log_store under the name 'rk_extractor:check_jpsi_misid'
log=log_store.add_logger('rk_extractor:check_jpsi_misid')
#-------------------------------------
class data:
    '''
    Configuration shared by the functions of this script.

    cas_dir is read from the CASDIR environment variable when the class body
    executes, so importing this module raises KeyError if CASDIR is unset.
    out_dir and kind start as None and are filled in by get_args().
    '''
    # Root of the input area, read once at import time
    cas_dir = os.environ['CASDIR']
    # Directory where the plots are written (set from the command line)
    out_dir = None
    # Sample type, 'data' or 'ctrl' (set from the command line)
    kind    = None

    # Edges of the q2 window drawn on the plots and used for the ctrl selection;
    # same units as the squared masses computed in check_misid
    low_q2  = 15_500_000
    hig_q2  = 22_000_000
#-------------------------------------
def get_rdf(trig=None):
    '''
    Build the RDataFrame for one trigger and apply the PID selection, plus the
    q2 selection when running on real data.

    Parameters
    ----------
    trig : str
        Trigger name; used as the tree name and as part of the input path.

    Returns
    -------
    RDataFrame node with the selection applied.

    Exits the process with status 1 if no input file matches the wildcard or
    if no entry survives the selection.
    '''
    # Real data: all years; ctrl sample: only 2018
    year    = '*' if data.kind == 'data' else '2018'
    file_wc = f'{data.cas_dir}/tools/apply_selection/no_pid_qsq/{data.kind}/v10.21p2/{year}_{trig}/*.root'
    # glob returns files in arbitrary order; sort for reproducible processing
    l_file  = sorted(glob.glob(file_wc))
    if not l_file:
        log.error(f'No files found in: {file_wc}')
        raise SystemExit(1)

    rdf     = ROOT.RDataFrame(trig, l_file)
    # Cut definitions are always taken for the 2018, high-q2 configuration
    qsq     = rs.get( 'q2', trig, q2bin='high', year = '2018')
    pid     = rs.get('pid', trig, q2bin='high', year = '2018')
    rdf     = rdf.Filter(pid, 'pid')
    # q2 cut only for real data; for ctrl the window is applied later on the swapped q2
    if data.kind == 'data':
        rdf = rdf.Filter(qsq, 'qsq')

    nentries = rdf.Count().GetValue()
    if nentries == 0:
        log.error('No entries left after selection')
        rep = rdf.Report()
        rep.Print()
        # Non-zero status so the failure is visible to the caller;
        # a bare exit() would terminate with status 0
        raise SystemExit(1)

    return rdf
#-------------------------------------
def check_misid():
    '''
    Load MTOS candidates, build the swapped-hypothesis quantities with
    misid_check and produce the plots.

    misid_check is given ID 13 for L1, L2 and H; presumably this recomputes the
    H+lepton masses with H treated as a muon (confirm against misid_check).
    The resulting columns (H_swp, H_org, ...) are squared into qsq_swp and
    qsq_org; Jpsi_M is squared into Jpsi_M_sqr.
    '''
    rdf    = get_rdf(trig='MTOS')
    # Raw string: in a plain literal the backslash-w is an invalid escape
    # sequence (SyntaxWarning on Python >= 3.12); the pattern value is unchanged
    df_org = misid_check.rdf_to_df(rdf, r'(Jpsi|B|L1|L2|H)_(P\w|ID|M)$')
    obj    = misid_check(df_org, d_lep={'L1' : 13, 'L2' : 13}, d_had={'H' : 13})
    df_swp = obj.get_df()
    df     = pnd.concat([df_org, df_swp], axis=1)

    df['qsq_swp']    = df.H_swp  ** 2
    df['qsq_org']    = df.H_org  ** 2
    df['Jpsi_M_sqr'] = df.Jpsi_M ** 2

    plot_jmass(df)
    plot_bmass(df)
    if data.kind == 'ctrl':
        plot_swp_jpsi(df)
    else:
        plot_org_jpsi(df)
#-------------------------------------
def plot_bmass(df):
    '''
    Histogram the B mass (B_M) and save it as <out_dir>/bmass_<kind>.png.

    For the ctrl sample, only candidates whose swapped q2 (qsq_swp) lies
    strictly between data.low_q2 and data.hig_q2 are plotted. The title
    reports how many candidates were plotted.
    '''
    os.makedirs(data.out_dir, exist_ok=True)

    if data.kind == 'ctrl':
        in_window = (df.qsq_swp > data.low_q2) & (df.qsq_swp < data.hig_q2)
        df        = df[in_window]

    df.B_M.hist(bins=40, range=(5000, 5600), histtype='step', label='')
    plt.grid(False)
    plt.title(f'#Entries={len(df)}')
    plt.xlabel('$M(B)$')
    plt.ylabel('Entries')
    plt.savefig(f'{data.out_dir}/bmass_{data.kind}.png')
    plt.close('all')
#-------------------------------------
def plot_jmass(df):
    '''
    Histogram the swapped-hypothesis mass (H_swp) and save it as
    <out_dir>/jmass_<kind>.png.

    Reference lines are drawn at 3097 (J/psi) and 3686 (psi(2S)).
    '''
    os.makedirs(data.out_dir, exist_ok=True)

    df.H_swp.hist(bins=60, range=(2750, 3800) , histtype='step', label=r'$K^{\pm}_{\to\mu^{\pm}}$')
    # Raw strings: '\p' in a plain literal is an invalid escape sequence
    plt.axvline(x=3097, color='r', linestyle=':', label=r'$J/\psi$')
    plt.axvline(x=3686, color='k', linestyle=':', label=r'$\psi(2S)$')
    plt.grid(False)
    plt.xlabel(r'$M(J/\psi)$')
    plt.ylabel('Entries')
    plt.legend()

    plot_path = f'{data.out_dir}/jmass_{data.kind}.png'
    plt.savefig(plot_path)
    plt.close('all')
#-------------------------------------
def plot_swp_jpsi(df):
    '''
    Histogram the squared swapped mass (qsq_swp) and save it as
    <out_dir>/swp_jpsi_<kind>.png.

    The title reports how many candidates fall strictly inside the window
    (data.low_q2, data.hig_q2) out of the total. This is the same window
    plot_bmass applies to the ctrl sample and the one drawn on this plot.
    '''
    os.makedirs(data.out_dir, exist_ok=True)
    # Both edges of the drawn window are applied, consistently with plot_bmass;
    # candidates above hig_q2 are outside the window and are not leakage
    in_window = (df.qsq_swp > data.low_q2) & (df.qsq_swp < data.hig_q2)
    df_leak   = df[in_window]

    nlk = len(df_leak)
    nto = len(df)

    df.qsq_swp.hist(bins=40, range=(0, 25000000) , histtype='step', label=r'$K^{\pm}_{\to\mu^{\pm}}$')
    plt.axvline(x=data.low_q2, color='r', linestyle=':')
    plt.axvline(x=data.hig_q2, color='r', linestyle=':')
    plt.title(f'Leak={nlk}/{nto}')
    plt.grid(False)
    # The x axis is the squared mass (qsq_swp = H_swp**2), not the mass
    plt.xlabel(r'$M^2(J/\psi)$')
    plt.ylabel('Entries')
    plt.legend()

    plot_path = f'{data.out_dir}/swp_jpsi_{data.kind}.png'
    plt.savefig(plot_path)
    plt.close('all')
#-------------------------------------
def plot_org_jpsi(df):
    '''
    Histogram the squared original J/psi mass (Jpsi_M_sqr) and save it as
    <out_dir>/org_jpsi_<kind>.png, with the q2 window edges drawn as lines.
    '''
    os.makedirs(data.out_dir, exist_ok=True)

    df.Jpsi_M_sqr.hist(bins=40, range=(0, 25000000) , histtype='step')
    plt.axvline(x=data.low_q2, color='r', linestyle=':')
    plt.axvline(x=data.hig_q2, color='r', linestyle=':')
    plt.grid(False)
    # The x axis is Jpsi_M**2, i.e. a squared mass
    plt.xlabel(r'$M^2(J/\psi)$')
    plt.ylabel('Entries')
    # No plt.legend(): none of the artists above carries a label, so it would
    # only emit a "No artists with labels found" warning and an empty legend

    plot_path = f'{data.out_dir}/org_jpsi_{data.kind}.png'
    plt.savefig(plot_path)
    plt.close('all')
#-------------------------------------
def get_args():
    '''
    Parse the command line and store the options in the data class.

    -d/--out_dir : output directory for the plots (required)
    -k/--kind    : sample type, 'data' or 'ctrl' (required)
    '''
    # Description previously referred to TCKs, which is unrelated to this script
    parser = argparse.ArgumentParser(description='Used to make plots checking the leakage of J/psi candidates with a misidentified hadron into the high q2 region')
    parser.add_argument('-d', '--out_dir' , type=str, help='Output directory', required=True)
    parser.add_argument('-k', '--kind'    , type=str, help='Sample type'     , required=True, choices=['data', 'ctrl'])
    args = parser.parse_args()

    data.out_dir = args.out_dir
    data.kind    = args.kind
#-------------------------------------
def main():
    '''
    Script entry point: read the command-line options, then run the check.
    '''
    get_args()     # fills data.out_dir and data.kind
    check_misid()  # loads the candidates and writes the plots
#-------------------------------------
if __name__ == '__main__':
    # Run only when executed as a script, not when imported
    main()

