from typing import List, Tuple
import numpy as np
import itertools
import numpy as np

class Rhombus:
    """A single rhombic-prism scintillator pixel represented as a triangle mesh.

    The prism is centred on ``origin_point``: its rhombic cross-section lies
    in the x-y plane (corners at +/-``x_len`` on x and +/-``y_len`` on y) and
    it spans +/-``z_len`` along z.
    """

    # The origin of each pixel is the centre of the rhombic solid.
    # l: width of the scintillator [cm]
    l = 4
    theta = np.radians(7)
    # x_len, y_len: distances from the centre to the rhombus corners along x
    # and y (half-diagonals); x is the shorter diagonal.
    x_len = l * np.sin(theta) / np.sin(2*theta)
    y_len = l * np.cos(theta) / np.sin(2*theta)
    # Half-extent of the scintillator along z (vertices sit at +/- z_len).
    z_len = 2

    # Quadrilateral faces as vertex-index quadruples: top, bottom, then the
    # four sides. Each one is turned into triangles via 3-combinations.
    _QUADS = (
        (0, 1, 2, 3),
        (4, 5, 6, 7),
        (0, 1, 4, 5),
        (1, 2, 5, 6),
        (2, 3, 6, 7),
        (3, 0, 7, 4),
    )

    def __init__(self, index: int, origin_point: Tuple[float, float, float]) -> None:
        """Create pixel ``index`` centred at ``origin_point`` and build its mesh."""
        self._origin_point = np.array(origin_point)
        self._index = index
        self.create_polygon()

    def create_polygon(self) -> None:
        """Build ``_vertices`` (8 x 3 coordinates) and ``_faces`` (triangles).

        Vertices 0-3 are the top (+z_len) corners and 4-7 the matching bottom
        (-z_len) corners, each ordered +y, +x, -y, -x. Every quadrilateral
        face is split into all four 3-vertex combinations of its corners,
        which covers the face (with overlap) independent of winding order.
        """
        corner_offsets = [
            [0, self.y_len],
            [self.x_len, 0],
            [0, -self.y_len],
            [-self.x_len, 0],
        ]
        self._vertices = np.array([
            self._origin_point + [dx, dy, z]
            for z in (self.z_len, -self.z_len)
            for dx, dy in corner_offsets
        ])

        # triangle combinations of every quadrilateral face
        self._faces = np.array([
            tri for quad in self._QUADS for tri in itertools.combinations(quad, 3)
        ])

    def get_index(self) -> int:
        """Return the identifier given at construction."""
        return self._index

    def get_vertices(self):
        """Return the (8, 3) vertex coordinate array."""
        return self._vertices

    def get_faces(self):
        """Return the (n_triangles, 3) vertex-index array."""
        return self._faces
from tqdm import tqdm
import ROOT as r
import numpy as np
from . import calibrationUtils as util


class TrackSeeker(r.TChain):
    """TChain over the ``VadcHigh`` branch with per-channel hit finding.

    Channel layout constants:
    * ``OUTER_CHNNELS`` / ``INNER_CHANNELS`` split the 64 channels into the
      outermost ones and the rest.
    * ``VERTICAL_GROUP`` holds 8 tuples of 8 channels; in ``calc_effeciency``
      the neighbours of a channel inside its tuple are used as the channels
      below (index - 1) and above (index + 1) it.
    """
    OUTER_CHNNELS = set([
        0, 1, 2, 3, 60, 61, 62, 63,
        28, 29, 30, 31, 32, 33, 34, 35
        ])
    INNER_CHANNELS = set(range(64)) - OUTER_CHNNELS
    VERTICAL_GROUP_EAST_BOARD = tuple(
        tuple(j + 4*i for i in range(8))
        for j in range(4)
    )
    VERTICAL_GROUP_WEST_BOARD = tuple(
        tuple(reversed([32 + j + 4*i for i in range(8)]))
        for j in range(4)
    )
    VERTICAL_GROUP = VERTICAL_GROUP_EAST_BOARD + VERTICAL_GROUP_WEST_BOARD

    def __init__(self, name, filepath):
        """Create the chain ``name`` over ``filepath`` and load the 64 histograms.

        Only the ``VadcHigh`` branch is enabled for reading.
        """
        # receive args
        super().__init__(name)
        self._filepath = filepath
        # TChain settings: disable every branch except VadcHigh
        self.SetBranchStatus("*", 0)
        self.SetBranchStatus("VadcHigh")
        self.Add(self._filepath)
        # prepare member variables
        self._prepare_variables()
        # exec member functions
        self.fetch_hist()

    def _prepare_variables(self):
        """Initialise per-channel state (thresholds, Landau fit ranges and functions)."""
        self._n_event = self.GetEntries()
        self._threshold = [0 for _ in range(64)]
        self._landau_fit_range = [[1200, 2700] for _ in range(64)]
        self._f_landau = [r.TF1("f_landau{}".format(ch), "landau", 0, 4096) for ch in range(64)]
        self._is_hit = []
        self._effeciency = [None for _ in range(64)]

    def set_threshold(self, ch, adc_th):
        """Set the ADC hit threshold of channel ``ch``."""
        self._threshold[ch] = adc_th

    def fetch_hist(self):
        """Load the ``ADC_HIGH_<ch>`` histogram of every channel via ``util.getHistMPPC``."""
        self._hist = [util.getHistMPPC(self._filepath, ch) for ch in range(64)]
        print("loaded VadcHigh 64ch as histogram")

    def set_landau_fit_range(self, ch, fit_range_min, fit_range_max):
        """Set the x range used by ``fit_by_landau`` for channel ``ch``."""
        self._landau_fit_range[ch] = [fit_range_min, fit_range_max]

    def fit_by_landau(self, ch):
        """Fit channel ``ch``'s histogram with its Landau TF1 over the configured range."""
        fitmin, fitmax = self._landau_fit_range[ch]
        self._hist[ch].Fit(self._f_landau[ch], "", "", fitmin, fitmax)

    def save_hist(self):
        """Draw all 64 histograms on a 4 x 16 canvas and save it as ``<file>.png``."""
        canvas = r.TCanvas("c", "c", 1920*2, 1080*16)
        canvas.Divide(4, 16)
        for ch in range(64):
            canvas.cd(ch+1)
            self._hist[ch].Draw()
        save_name = self._filepath.split('/')[-1].replace(".root", ".png")
        canvas.SaveAs(save_name)

    def determine_hit_by_landau_fit(self):
        """Decide, for every event and channel, whether VadcHigh exceeds the threshold.

        Results are stored (one 64-element boolean array per event) in
        ``_is_hit``, which is rebuilt on every call. They are also written to a
        ``tree_bool`` TTree saved as ``tmp.root``.
        """
        tree = r.TTree("tree_bool", "tree_bool")
        # The branch is bound to this buffer's memory, so it must be updated
        # in place; rebinding the name would leave the branch reading zeros.
        is_hit = np.zeros(64, dtype=bool)
        tree.Branch("is_hit", is_hit, "is_hit[64]/O")

        self._threshold = np.array(self._threshold)
        self._is_hit = []
        for i_event in tqdm(range(self._n_event), desc="making is_hit", leave=False):
            self.GetEntry(i_event)
            VadcHigh = np.array(self.VadcHigh)
            is_hit[:] = self._threshold < VadcHigh
            # store a copy: the buffer itself is overwritten for the next event
            self._is_hit.append(is_hit.copy())
            tree.Fill()
        # TODO: consider hit info saving
        tree.SaveAs("tmp.root")

    def calc_effeciency(self, ch_target):
        """Estimate the detection efficiency of inner channel ``ch_target``.

        Events in which both the channel below and the channel above (its
        neighbours in ``VERTICAL_GROUP``) are hit form the denominator; those
        in which ``ch_target`` is also hit form the numerator. The result is
        stored in ``_effeciency[ch_target]``. Requires
        ``determine_hit_by_landau_fit`` to have been run.

        Raises:
            ValueError: if ``ch_target`` is not an inner channel.
        """
        if ch_target not in self.INNER_CHANNELS:
            print("ch {} is not inner scintillator!".format(ch_target))
            raise ValueError("ch {} is not inner scintillator!".format(ch_target))
        # search upside, downside channel; inner channels are never at the
        # ends of their group, so both neighbours exist.
        group = next(g for g in self.VERTICAL_GROUP if ch_target in g)
        position = group.index(ch_target)
        ch_down = group[position - 1]
        ch_up = group[position + 1]
        print("target channel of effeciency calculation is {}, up{}, down{}".format(ch_target, ch_up, ch_down))
        # calc effeciency
        hits = np.asarray(self._is_hit[:self._n_event], dtype=bool)
        reference = hits[:, ch_up] & hits[:, ch_down]
        ok = int(np.count_nonzero(reference & hits[:, ch_target]))
        ng = int(np.count_nonzero(reference)) - ok
        if (ok+ng) == 0:
            print("There is no event to calc effeciency ch {}".format(ch_target))
        else:
            self._effeciency[ch_target] = ok / (ok + ng)
            print("detectation effeciency of ch {} is {:.5f}%".format(
                ch_target,
                100*self._effeciency[ch_target]
            ))
            print("used {} events".format(ok+ng))
from typing import List
import numpy as np
import uproot

def get_threshold():
    """Return the 64 per-channel ADC hit thresholds as a new list.

    Every channel uses 1300 except for a few hand-tuned channels listed in
    ``tuned`` below.
    """
    tuned = {
        1: 1150,
        2: 1000,
        4: 1200,
        5: 1150,
        11: 1100,
        12: 1200,
        14: 1100,
        28: 1100,
        35: 1050,
        37: 1000,
        57: 1200,
        59: 1200,
        61: 1100,
    }
    return [tuned.get(channel, 1300) for channel in range(64)]


class EffCalculatorUpDown:
    """Per-channel detection efficiency using reference hits above/below.

    Two threshold sets are used. Reference thresholds (``set_ref_threshold_s``)
    decide which events count for each channel. Test thresholds
    (``set_threshold_s``) decide whether the channel itself fired. The two
    setters may be called in either order; afterwards
    ``calc_all_ch_effeciency`` returns the 64 efficiencies.

    Reference events for a channel are:
    * inner channel: the channels directly above and below are both hit;
    * outer "up" channel: the 7 channels below it in its vertical group are hit;
    * outer "down" channel: the 7 channels above it in its vertical group are hit.
    """
    CHANNELS_UPSIDE = np.array([[i for i in range(0 + (4 * layer), 4 + (4 * layer))] for layer in range(8)])
    CHANNELS_DOWNSIDE = np.fliplr(np.array([[i for i in range(60 - (4 * layer), 64 - (4 * layer))] for layer in range(8)]))
    OUTER_UP_CHNNELS = set([
        28, 29, 30, 31, 32, 33, 34, 35
    ])
    OUTER_DOWN_CHNNELS = set([
        0, 1, 2, 3, 60, 61, 62, 63
    ])
    OUTER_CHNNELS = OUTER_UP_CHNNELS | OUTER_DOWN_CHNNELS
    INNER_CHANNELS = set(range(64)) - OUTER_CHNNELS
    VERTICAL_GROUP_EAST_BOARD = tuple(
        tuple(j + 4*i for i in range(8))
        for j in range(4)
    )
    VERTICAL_GROUP_WEST_BOARD = tuple(
        tuple(reversed([32 + j + 4*i for i in range(8)]))
        for j in range(4)
    )
    VERTICAL_GROUP = VERTICAL_GROUP_EAST_BOARD + VERTICAL_GROUP_WEST_BOARD

    def __init__(self, tree_name, rootfile_path):
        self._load_rootfile(rootfile_path, tree_name)

    def _load_rootfile(self, rootfile_path, tree_name):
        """Read the ``VadcHigh`` branch of ``tree_name`` into a NumPy array."""
        self._rootfile_path = rootfile_path
        with uproot.open(self._rootfile_path) as file:
            self._tree = file[tree_name]
            self._VadcHigh = self._tree["VadcHigh"].array(library="np")

    def set_ref_threshold_s(self, ref_threshold_s: List[int]):
        """Set the 64 reference thresholds and recompute each channel's reference events."""
        self._ref_threshold_s = np.array(ref_threshold_s)
        self._make_ref_hit()
        # The reference-event selection depends only on the reference hits,
        # so it is refreshed here, which keeps it consistent whatever order
        # the setters are called in.
        self._search_ref_events()

    def set_threshold_s(self, threshold_s: List[int]):
        """Set the 64 thresholds that decide whether a channel under test fired."""
        self._threshold_s = np.array(threshold_s)
        self._make_is_hit()

    def _make_ref_hit(self):
        # per event and channel: VadcHigh above the reference threshold
        self._ref_hit = self._VadcHigh > self._ref_threshold_s

    def _make_is_hit(self):
        # per event and channel: VadcHigh above the test threshold
        self._is_hit = self._VadcHigh > self._threshold_s

    def _search_ref_events(self):
        """Fill ``_use_index_s[ch]`` with the event indices usable for channel ``ch``."""
        self._use_index_s = dict()
        # inner channels: direct up and down neighbours
        for ch in self.INNER_CHANNELS:
            refs = [self._get_upside_channel(ch), self._get_downside_channel(ch)]
            self._use_index_s[ch] = self._events_with_all_ref_hit(refs)
        # upside channels: every channel below
        for ch in self.OUTER_UP_CHNNELS:
            refs = self._follow_channels(ch, self._get_downside_channel, 7)
            self._use_index_s[ch] = self._events_with_all_ref_hit(refs)
        # down side channels: every channel above
        for ch in self.OUTER_DOWN_CHNNELS:
            refs = self._follow_channels(ch, self._get_upside_channel, 7)
            self._use_index_s[ch] = self._events_with_all_ref_hit(refs)

    def _events_with_all_ref_hit(self, channels):
        """Return indices of events in which every channel in ``channels`` is a reference hit."""
        return np.where(self._ref_hit[:, list(channels)].all(axis=1))[0]

    @staticmethod
    def _follow_channels(ch, step, n_steps):
        """Apply ``step`` to ``ch`` ``n_steps`` times and return every channel visited."""
        visited = []
        for _ in range(n_steps):
            ch = step(ch)
            visited.append(ch)
        return visited

    def _calc_effeciency(self, ch) -> float:
        """Fraction of ``ch``'s reference events in which ``ch`` is hit (NaN if there are none)."""
        use_index = self._use_index_s[ch]
        if use_index.shape[0] == 0:
            # undefined efficiency: no reference event selected this channel
            return float("nan")
        n_ok = np.count_nonzero(self._is_hit[use_index, ch])
        return n_ok / use_index.shape[0]

    def calc_all_ch_effeciency(self):
        """Return the efficiencies of channels 0..63 as a list."""
        return list(map(self._calc_effeciency, range(64)))

    def _get_neighbour_channel(self, ch_target, offset):
        """Return the channel ``offset`` positions away from ``ch_target`` in its vertical group."""
        for group in self.VERTICAL_GROUP:
            if ch_target in group:
                return group[group.index(ch_target) + offset]
        raise ValueError("ch {} is not in any vertical group".format(ch_target))

    def _get_upside_channel(self, ch_target):
        return self._get_neighbour_channel(ch_target, +1)

    def _get_downside_channel(self, ch_target):
        return self._get_neighbour_channel(ch_target, -1)
import os
import json
import array
import ROOT as r
from copy import copy

# Build a TGraphErrors directly from Python lists


def TPGraphErrors(n, x, y, x_e, y_e):
    """Build a ``ROOT.TGraphErrors`` with ``n`` points from Python sequences.

    ``x``, ``y`` and their errors ``x_e``, ``y_e`` are each converted to an
    ``array.array('d', ...)`` of doubles before being handed to the
    TGraphErrors constructor.
    """
    double_buffers = [array.array('d', values) for values in (x, y, x_e, y_e)]
    return r.TGraphErrors(n, *double_buffers)

# .root => hist


def getHistMPPC(file_path, channel):
    """Return a copy of histogram ``ADC_HIGH_<channel>`` from ROOT file ``file_path``.

    The histogram title is set to ``file_path`` with "data/cal_2021" removed
    and ".root" replaced by "_", followed by ``<channel>ch``; the trailing
    ";ADC;Events" sets the axis titles (ROOT's "title;x;y" convention).

    NOTE(review): presumably ``file.Get`` yields None for a missing key, which
    would surface as an AttributeError on SetTitle -- confirm.
    NOTE(review): the ``copy`` is presumably what lets the histogram outlive
    the local TFile; ROOT object ownership is subtle here, so keep the
    statement order as is.
    """
    file = r.TFile(file_path)
    hist = file.Get("ADC_HIGH_" + str(channel))
    hist.SetTitle(file_path.replace("data/cal_2021", "").replace(".root", "_") + str(channel) + "ch;ADC;Events")
    return copy(hist)


def searchPeaks(hist, peak_max, sigma=10):
    """Locate peaks in ``hist`` with ``ROOT.TSpectrum``.

    Args:
        hist: histogram to search.
        peak_max: maximum number of peaks TSpectrum may report.
        sigma: expected peak width passed to ``TSpectrum.Search``.

    Returns:
        ``(n_peaks, x_peaks, y_peaks)``, where the position and height tuples
        are sorted by increasing x. With no peaks found this is ``(0, (), ())``
        instead of failing on an empty unpack.
    """
    spectrum = r.TSpectrum(peak_max)
    spectrum.Search(hist, sigma, "new")
    n_peaks = int(spectrum.GetNPeaks())
    if n_peaks == 0:
        return 0, (), ()
    x_peaks = spectrum.GetPositionX()
    y_peaks = spectrum.GetPositionY()
    # sort (x, y) pairs together so heights stay matched to positions
    peaks = sorted((x_peaks[i], y_peaks[i]) for i in range(n_peaks))
    ret_x_peaks, ret_y_peaks = zip(*peaks)
    return n_peaks, ret_x_peaks, ret_y_peaks

# Build a TFormula expression that sums several Gaussians


def getMultiGaussString(num):
    """Return a TFormula expression that sums ``num`` Gaussians.

    Gaussian k uses parameters 3k..3k+2, e.g. ``num=2`` gives
    ``"+gaus(0)+gaus(3)"``; ``num=0`` gives an empty string.
    """
    return "".join("+gaus(" + str(3 * k) + ")" for k in range(num))

# Fit one Gaussian per peak found in the histogram


def getFittedParams(
    hist,
    peak_search_range=(0, 1500),
    fitting_range=(0, 1500),
    showing_range=(0, 1500),
    peak_search_sigma=10
):
    """Fit a sum of Gaussians to the peaks of ``hist`` and return their means.

    The x axis is first restricted to ``peak_search_range`` and up to 100
    peaks are located with ``searchPeaks``. One Gaussian per peak is then
    fitted over ``fitting_range``, seeded with the peak height and position
    and a sigma of 5. Finally the displayed x range is set to
    ``showing_range``.

    Returns:
        ``(means, mean_errors)``: fitted Gaussian means and their errors, one
        per peak, in the order the peaks came from ``searchPeaks``.

    Raises:
        RuntimeError: if no peak is found (the fit function would be empty).
    """
    # peak search
    hist.GetXaxis().SetRangeUser(*peak_search_range)
    n_peaks, x_peaks, y_peaks = searchPeaks(hist, 100, peak_search_sigma)
    if n_peaks == 0:
        raise RuntimeError(
            "no peak found in peak search range {}".format(tuple(peak_search_range))
        )
    multi_gauss_str = getMultiGaussString(n_peaks)

    # fitting
    f_fit = r.TF1("f", multi_gauss_str, *fitting_range)
    f_fit.SetNpx(10000)
    for i in range(n_peaks):
        f_fit.SetParName(3*i + 0, str(i) + "th const")
        f_fit.SetParName(3*i + 1, str(i) + "th mean")
        f_fit.SetParName(3*i + 2, str(i) + "th sigma")
        f_fit.SetParameter(3*i + 0, y_peaks[i])
        f_fit.SetParameter(3*i + 1, x_peaks[i])
        f_fit.SetParameter(3*i + 2, 5)
        f_fit.SetParLimits(3*i + 0, 0, 10**6)
        # limit the sigma of *this* Gaussian (was hard-wired to parameter 5)
        f_fit.SetParLimits(3*i + 2, 0, 100)
    hist.Fit(f_fit, "R")

    # set showing range
    hist.GetXaxis().SetRangeUser(*showing_range)

    # return
    ret_adc_means = [f_fit.GetParameter(3*i + 1) for i in range(n_peaks)]
    ret_adc_mean_errors = [f_fit.GetParError(3*i + 1) for i in range(n_peaks)]
    return ret_adc_means, ret_adc_mean_errors

# Build the ADC-vs-photon-number calibration line.
# Returns (slope, intercept) of the linear fit ADC = a * photons + b.


def getCalibrationParams(
    json_file_path="/home/hamada/b4ex/ensoku/json/cal_20211112_16_36.json",
):
    """Fit an ADC-vs-photon-number calibration line from a JSON settings file.

    Settings keys read: ``root_file_path``, ``target_channel``,
    ``image_save_path``, ``peak_search_range``, ``fitting_range``,
    ``showing_range``, ``initial_photon_num`` and the optional
    ``peak_search_sigma`` (default 10).

    The fitted peak means are assigned consecutive photon numbers starting
    at ``initial_photon_num`` and fitted with ``[0]*x + [1]``. The histogram
    and graph are saved as ``hist.png`` and ``graph.png`` inside
    ``image_save_path``.

    Returns:
        ``(a, b)`` with ``ADC = a * photon_number + b``.
    """
    # fetch json (closing the file promptly), hist
    with open(json_file_path) as json_file:
        settings = json.load(json_file)
    title = settings["root_file_path"] + " " + str(settings["target_channel"])
    peak_search_sigma = settings.get("peak_search_sigma", 10)
    hist = getHistMPPC(settings["root_file_path"], settings["target_channel"])

    # prepare dir
    image_dir = settings["image_save_path"]
    os.makedirs(image_dir, exist_ok=True)

    # fetch fitting params
    adc_means, adc_mean_errors = getFittedParams(
        hist,
        settings["peak_search_range"],
        settings["fitting_range"],
        settings["showing_range"],
        peak_search_sigma
    )

    # init graph
    n_points = len(adc_means)
    photon_nums = [settings["initial_photon_num"] + i for i in range(n_points)]
    photon_num_errors = [0 for _ in range(n_points)]
    g = TPGraphErrors(n_points, photon_nums, adc_means, photon_num_errors, adc_mean_errors)
    g.SetTitle(title + ";Photon Number;ADC Value")
    g.SetMarkerStyle(8)
    g.SetMarkerSize(1)

    # init linear function for fitting and fit
    f_fit = r.TF1("f_liner", "[0]*x + [1]", 0, 20)
    g.Fit(f_fit, "R")

    # init axis frame for the TGraph, spanning the fitted line
    photon_num_range = (0, photon_nums[-1] + 1)
    adc_range = tuple(map(f_fit.Eval, photon_num_range))
    axis = r.TH2D(
        "axis", title + ";Photon Number;ADC Value",
        0, *photon_num_range,
        0, *adc_range
    )
    axis.SetStats(0)

    # save image; join paths so a missing trailing separator in
    # image_save_path still puts the files inside the created directory
    c1 = r.TCanvas()
    hist.Draw()
    c1.SaveAs(os.path.join(image_dir, "hist.png"))
    c2 = r.TCanvas()
    axis.Draw("AXIS")
    g.Draw("P SAME")
    c2.SaveAs(os.path.join(image_dir, "graph.png"))

    # return
    # y   = ax         + b
    # ADC = a * Photon + b
    a, b = f_fit.GetParameter(0), f_fit.GetParameter(1)
    return a, b
from re import S
import os
from .TrackReconstructorBase import TrackReconstructorBase
import numpy as np
from tqdm import tqdm
import itertools


class MuonTrackReconstructor(TrackReconstructorBase):
    """Event selection for straight, muon-like tracks in the layered hit array.

    The selection steps are meant to run in order, each reading attributes
    set by the previous step:
    ``_pre_cut_threshold_layer`` -> ``_multi_hit`` -> ``_under_layer_limit``
    -> ``hit_muon_straight``. ``write_fig`` uses the cut parameters stored by
    those steps in its output directory name.
    """

    def __init__(self, rootfile_path, threshold_s) -> None:
        super().__init__(rootfile_path, threshold_s)
        self._hit_array = self._hit_array_gen.get_hit_array()
        print(self._hit_array.shape[0])

    def _pre_cut_threshold_layer(self):
        """Keep events in which at least 6 layers have one or more hit pixels."""
        # hits per layer: sum over axes 2 and 3 (the pixel grid of a layer)
        self._layer_n_hit = np.sum(self._hit_array, axis=(2, 3))
        self._hit_layer_number = np.count_nonzero(self._layer_n_hit, axis=1)
        self._threshold_layer_number = 6
        self._pre_cut_index = np.where(self._hit_layer_number >= self._threshold_layer_number)[0]
        self._pre_cut_array = self._hit_array[self._pre_cut_index]
        print("preselected, {:.2f}% remain".format(
            100 * self._pre_cut_array.shape[0] / self._hit_array.shape[0]
        ))

    # Look at every layer and keep events in which at least one layer has at
    # least the chosen number of hits.
    def _multi_hit(self):
        """From the pre-selected events, keep those with a layer of >= 4 hits."""
        self._multi_hit_event_index = []
        self._threshold_hit = 4
        for i_event in tqdm(self._pre_cut_index):
            if np.any(self._layer_n_hit[i_event] >= self._threshold_hit):
                self._multi_hit_event_index.append(i_event)

        print("multi_hit event is ", len(self._multi_hit_event_index))
        print("selected multi_hit event, {:.2f}% remain".format(
            100 * len(self._multi_hit_event_index) / self._hit_array.shape[0]
        ))

    # Filter by how many hits occurred in the layers below the chosen layer.
    def _under_layer_limit(self):
        """Keep multi-hit events with >= threshold hits in one of the first 8 - 6 layers."""
        self._under_layer_limit_index = []
        self._origin_layar = 6
        # Original note: hit_array is ordered from the top layer downwards,
        # hence 8 - origin_layer.
        # NOTE(review): TrackReconstructorBase draws layer index 0 at the
        # lowest z; confirm which end of the array is "under".
        self._orogin_layer_under = 8 - self._origin_layar
        for i_event in tqdm(self._multi_hit_event_index):
            if np.any(self._layer_n_hit[i_event][0: self._orogin_layer_under] >= self._threshold_hit):
                self._under_layer_limit_index.append(i_event)

        print("multi_hit event & under layer cut is ", len(self._under_layer_limit_index))
        print("selected multi_hit & under layer cut event, {:.2f}% remain".format(
            100 * len(self._under_layer_limit_index) / self._hit_array.shape[0]
        ))

    def hit_muon_straight(self):
        """Keep events in which every hit in layers 0-6 continues into the next two layers.

        For each hit pixel (layer, i, j) with layer < 7, the neighbourhood
        i-1..i+1, j-1..j+1 (clipped at the grid edges) in layers layer+1 and
        layer+2 must contain at least one hit; otherwise the event is dropped.
        """
        self._hit_muon_index = []
        for i_event in tqdm(self._under_layer_limit_index):
            if self._is_straight_event(self._hit_array[i_event]):
                self._hit_muon_index.append(i_event)

        print("straight event & multi_hit event & under layer cut is ", len(self._hit_muon_index))
        print("selected straight event & multi_hit event & under layer cut, {:.2f}% remain".format(
            100 * len(self._hit_muon_index) / self._hit_array.shape[0]
        ))

    @staticmethod
    def _is_straight_event(event_hits):
        """Return True if every hit in layers 0-6 of one event has a hit nearby below it."""
        for i_layer, i, j in itertools.product(range(7), range(4), range(4)):
            if event_hits[i_layer][i][j] != 1:
                continue
            # Clip the lower bounds at 0: a start of -1 would wrap to the last
            # index and give an empty slice, and np.all() of an empty slice is
            # True, so a hit in the (0, 0) corner would always reject the event.
            neighbourhood = event_hits[i_layer + 1:i_layer + 3,
                                       max(i - 1, 0):i + 2,
                                       max(j - 1, 0):j + 2]
            if np.all(neighbourhood == 0):
                return False
        return True

    def write_fig(self, i_event):
        """Save the current figure as PNG and HTML in a directory named after the cuts."""
        filename_short = self._rootfile_path.split('/')[-1]
        self._save_directory_path = "img_{}_{}_layer_hits_{}_hits_under_cut_{}_layer".format(filename_short, self._threshold_layer_number, self._threshold_hit, self._origin_layar)
        os.makedirs(self._save_directory_path, exist_ok=True)
        self._fig.write_image(self._save_directory_path + "/event{}.png".format(i_event), scale=10)
        print(self._save_directory_path + "/event{}.png".format(i_event))
        self._fig.write_html(self._save_directory_path + "/event{}.html".format(i_event))
        print(self._save_directory_path + "/event{}.html".format(i_event))
from typing import List, Tuple
import itertools
import numpy as np
import ROOT as r
import os
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from .Rhombus import Rhombus
from .HitArrayGen import HitArrayGen
r.gROOT.SetBatch()


class TrackReconstructorBase:
    """Base for track reconstruction: pixel geometry, hit array and 3-D event display.

    The detector is modelled as 8 layers of 4 x 4 rhombic pixels. Pixel
    centres are stored in ``_point[layer][i][j]`` and the per-event hit
    pattern comes from ``HitArrayGen``. Events are drawn with plotly in three
    views (perspective, top and side).
    """
    SHOWING_RANGE = (-100, 100)
    # rhombus half-diagonals (class attributes of Rhombus)
    X_LEN = Rhombus.x_len
    Y_LEN = Rhombus.y_len

    def __init__(self, rootfile_path: str, threshold_s: List[int]) -> None:
        """Load ``rootfile_path`` and build the hit array with the 64 ``threshold_s``.

        Raises:
            ValueError: if ``threshold_s`` does not have exactly 64 entries.
        """
        if len(threshold_s) != 64:
            print("invalid threshold list length")
            raise ValueError("invalid threshold list length")
        self._rootfile_path = rootfile_path
        self._filename_short = self._rootfile_path.split('/')[-1]
        self._hit_array_gen = HitArrayGen(self._rootfile_path)
        for ch in range(64):
            self._hit_array_gen.set_threshold(ch, threshold_s[ch])
        self._hit_array_gen.generate_hit_array()
        self._hit_array = self._hit_array_gen.get_hit_array()
        self._make_point()
        self._init_fig()

    def _make_point(self):
        """Compute the (8, 4, 4, 3) array of pixel centre coordinates."""
        # builtin float: the np.float alias was removed in NumPy 1.24
        self._point = np.zeros(shape=[8, 4, 4, 3], dtype=float)
        for i_layer in range(-4, 4):
            for i in range(-2, 2):
                for j in range(-2, 2):
                    origin_point = [self.X_LEN * (i - j), self.Y_LEN * (i + j), 8 * i_layer]
                    self._point[i_layer + 4][i + 2][j + 2] = np.array(origin_point, dtype=float)

    def get_point(self, i_layer, i, j):
        """Return the centre of pixel (i, j) in layer ``i_layer`` (all 0-based)."""
        return self._point[i_layer][i][j]

    def _init_fig(self):
        """Create a fresh 2 x 2 plotly figure with three 3-D scenes."""
        self._fig = make_subplots(
            rows=2, cols=2,
            specs=[
                [{'type': 'mesh3d'}, {'type': 'mesh3d'}],
                [{'type': 'mesh3d'}, {}]
            ]
        )
        layout = go.Layout(
            scene=dict(
                xaxis=dict(nticks=1, range=self.SHOWING_RANGE,),
                yaxis=dict(nticks=1, range=self.SHOWING_RANGE,),
                zaxis=dict(nticks=1, range=self.SHOWING_RANGE,),
            )
        )
        self._fig.update_layout(layout)
        self._fig.update_layout(height=900, width=1500)
        self._fig.update_layout(scene2_aspectmode='data')
        self._fig.update_layout(scene3_aspectmode='data')
        self._fig.update_scenes(
            camera=dict(
                eye=dict(x=0.7, y=0, z=0)
            ),
            xaxis_showticklabels=False,
            yaxis_showticklabels=False,
            zaxis_showticklabels=False,
            row=1, col=1
        )
        self._fig.update_scenes(
            camera=dict(
                eye=dict(x=0, y=3.0, z=0)
            ),
            camera_projection_type="orthographic",
            xaxis_showticklabels=False,
            yaxis_showticklabels=False,
            zaxis_showticklabels=False,
            row=2, col=1
        )
        self._fig.update_scenes(
            camera=dict(
                eye=dict(x=0, y=0, z=2.7)
            ),
            camera_projection_type="orthographic",
            xaxis_showticklabels=False,
            yaxis_showticklabels=False,
            zaxis_showticklabels=False,
            row=1, col=2
        )
        self._fig.update_scenes(xaxis_visible=False, yaxis_visible=False, zaxis_visible=False)

    def draw_mesh(self, i_event):
        """Add every pixel of event ``i_event`` to all three scenes (hit = black, else faint cyan)."""
        self._prepare_pixel_attributes(i_event)
        self.data_scinti_mesh = []
        for i_pixels, pix in enumerate(self.pixels):
            x, y, z = pix.get_vertices().T
            i, j, k = pix.get_faces().T
            self.data_scinti_mesh.append(
                go.Mesh3d(
                    x=x, y=y, z=z, i=i, j=j, k=k,
                    color=self.pix_color[i_pixels],
                    opacity=0.05 if self.pix_color[i_pixels] == "cyan" else 1
                )
            )
        for i in range(len(self.data_scinti_mesh)):
            self._fig.add_trace(self.data_scinti_mesh[i], row=1, col=1)
            self._fig.add_trace(self.data_scinti_mesh[i], row=1, col=2)
            self._fig.add_trace(self.data_scinti_mesh[i], row=2, col=1)
        self._fig.update_layout(title_text="{} event {}".format(self._rootfile_path, i_event))

    def write_fig(self, i_event):
        """Save the current figure as ``img_<file>/event<i>.png`` and ``.html``."""
        filename_short = self._filename_short
        os.makedirs("img_{}".format(filename_short), exist_ok=True)
        self._fig.write_image("img_{}/event{}.png".format(filename_short, i_event), scale=10)
        print("img_{}/event{}.png".format(filename_short, i_event))
        self._fig.write_html("img_{}/event{}.html".format(filename_short, i_event))
        print("img_{}/event{}.html".format(filename_short, i_event))

    def _prepare_pixel_attributes(self, i_event):
        """Build one Rhombus per pixel and its colour from the hit array of ``i_event``."""
        self._i_event = i_event
        self.pixels = []
        self.pix_color = []
        self.pix_index = 0
        for i_layer in range(-4, 4):
            for i in range(-2, 2):
                for j in range(-2, 2):
                    pix = Rhombus(
                        self.pix_index,
                        [self.X_LEN * (i - j), self.Y_LEN * (i + j), 8 * i_layer]
                    )
                    if self._hit_array[i_event][i_layer + 4][i + 2, j + 2] == 1.0:
                        self.pix_color.append("black")
                    else:
                        self.pix_color.append("cyan")
                    self.pixels.append(pix)
                    self.pix_index += 1

    def draw_line(self, point_1: Tuple[float], point_2: Tuple[float], line_width=0.7):
        """
        USE AFTER _init_fig()

        Draw a red square-section bar from ``point_1`` to ``point_2`` in all
        three scenes; ``line_width`` is the half-width of the bar in x and y.
        """
        self._points = [point_1, point_2]
        edges = []
        for point in self._points:
            for dx in [line_width, -line_width]:
                for dy in [line_width, -line_width]:
                    edges.append([
                        point[0] + dx,
                        point[1] + dy,
                        point[2]
                    ])
        edges = np.array(edges)
        # every vertex triple: covers all faces of the bar without ordering
        faces = np.array(tuple(itertools.combinations(range(len(edges)), 3)))
        edge_x, edge_y, edge_z = edges.T
        face_i, face_j, face_k = faces.T
        line_mesh_data = go.Mesh3d(
            x=edge_x, y=edge_y, z=edge_z,
            i=face_i, j=face_j, k=face_k,
            color="red"
        )
        self._fig.add_trace(line_mesh_data, row=1, col=1)
        self._fig.add_trace(line_mesh_data, row=1, col=2)
        self._fig.add_trace(line_mesh_data, row=2, col=1)

    def show(self, i_event):
        """Redraw the figure for ``i_event`` and write it to disk."""
        self._init_fig()
        self.draw_mesh(i_event)
        self.write_fig(i_event)
import enum
import numpy as np
from tqdm import tqdm
from .TrackSeeker import TrackSeeker
import ROOT as r
r.gROOT.SetBatch()


class EffCalculator(TrackSeeker):
    """Channel efficiency from near-vertical muon candidates.

    After a Landau fit of every channel, an event is a candidate for a
    vertical group when more than ``n_hit`` of that group's channels exceed
    their fitted most probable value. A channel's efficiency is the fraction
    of its group's candidate events in which the channel is above its hit
    threshold (set through ``determine_hits``).
    """

    def __init__(self, n_hit, tree_name, filepath):
        super().__init__(tree_name, filepath)
        self._n_hit_required = n_hit
        self._exec_landau_fit()
        self._search_all_vertical_mu_event()

    def _prepare_variables(self):
        super()._prepare_variables()
        self._vertical_event_list = [[] for _ in range(len(self.VERTICAL_GROUP))]

    def _exec_landau_fit(self):
        """Landau-fit every channel and store the fitted MPV (parameter 1)."""
        # builtin float: the np.float alias was removed in NumPy 1.24
        self._fitted_MPV_s = np.zeros(64, dtype=float)
        for ch in range(64):
            self.fit_by_landau(ch)
            self._fitted_MPV_s[ch] = self._f_landau[ch].GetParameter(1)

    def _search_vertical_mu_event(self, vertical_group_index, VadcHigh, i_event):
        """Record ``i_event`` for the group if more than n_hit of its channels exceed their MPV."""
        target_vertical_channels = np.array(self.VERTICAL_GROUP[vertical_group_index])
        n_vertical_hit = np.sum(self._fitted_MPV_s[target_vertical_channels] < VadcHigh[target_vertical_channels])
        if n_vertical_hit > self._n_hit_required:
            self._vertical_event_list[vertical_group_index].append(i_event)

    def _search_all_vertical_mu_event(self):
        """Scan all events for vertical candidates; also build the sorted union of them."""
        for i_event in tqdm(range(self._n_event), desc="search vertical"):
            self.GetEntry(i_event)
            VadcHigh_np = np.array(self.VadcHigh)
            for vertical_group_index in range(len(self.VERTICAL_GROUP)):
                self._search_vertical_mu_event(vertical_group_index, VadcHigh_np, i_event)
        self._vertical_event_list_sum = sorted(set().union(*self._vertical_event_list))

    def _set_threshold(self, ch, adc_th):
        super().set_threshold(ch, adc_th)

    def determine_hits(self, adc_threshold_s: list):
        """Set the 64 hit thresholds and evaluate hits for the candidate events.

        Raises:
            ValueError: if ``adc_threshold_s`` does not have exactly 64 entries.
        """
        if len(adc_threshold_s) != 64:
            print("invalid arg list length")
            raise ValueError("invalid arg list length")
        self._is_hit = []
        for ch, adc_th in enumerate(adc_threshold_s):
            self._set_threshold(ch, adc_th)
        self.determine_hit_by_landau_fit()

    def calc_effeciency(self, ch_target):
        """Disabled in this subclass; use ``get_64ch_effeciency`` instead."""
        print("this method is unused")
        raise NotImplementedError("this method is unused")

    def _calc_effeciency(self, ch_target):
        """Efficiency of ``ch_target`` over its group's candidates, or -1 if there are none."""
        vertical_group_index = next(
            i for i, group in enumerate(self.VERTICAL_GROUP) if ch_target in group
        )
        candidate_events = self._vertical_event_list[vertical_group_index]
        if len(candidate_events) == 0:
            return -1
        n_hitted = sum(1 for i_event in candidate_events if self._is_hit[i_event][ch_target])
        return n_hitted / len(candidate_events)

    def get_64ch_effeciency(self):
        """Return the efficiencies of channels 0..63 (-1 where undefined)."""
        return [self._calc_effeciency(ch) for ch in tqdm(range(64), desc="calc eff")]

    def determine_hit_by_landau_fit(self):
        """Fill ``_is_hit`` (n_event x 64, int 0/1) for the vertical candidate events only."""
        # builtin int: the np.int alias was removed in NumPy 1.24
        self._is_hit = np.zeros([self._n_event, 64], dtype=int)
        self._threshold = np.array(self._threshold)
        for i_event in tqdm(self._vertical_event_list_sum):
            self.GetEntry(i_event)
            VadcHigh = np.array(self.VadcHigh)
            self._is_hit[i_event] = self._threshold < VadcHigh
#from pyroot_easiroc import *from tqdm import tqdm
import uproot
import numpy as np
import pickle
from . import calibrationUtils as util


class HitArrayGen:
    """Build per-event pixel hit maps from the ``VadcHigh`` branch.

    Each of the 8 layers has 4 "upside" and 4 "downside" channels
    (``CHANNELS_UPSIDE`` / ``CHANNELS_DOWNSIDE``, shape (8, 4)). Pixel
    (p, q) of a layer is hit when upside channel p AND downside channel q
    are both above their thresholds, giving a boolean array of shape
    (n_event, 8, 4, 4).
    """
    CHANNELS_UPSIDE = np.array([[i for i in range(0 + (4 * layer), 4 + (4 * layer))] for layer in range(8)])
    CHANNELS_DOWNSIDE = np.fliplr(np.array([[i for i in range(60 - (4 * layer), 64 - (4 * layer))] for layer in range(8)]))

    def __init__(self, rootfile_path):
        """
        Load ``VadcHigh`` from the TTree named ``tree`` in ``rootfile_path``.

        NOT compatible with wildcard / regular-expression paths (ex. *, [0-9]):
        the path is opened as a single file.

        Raises:
            RuntimeError: if the installed uproot is older than version 4.
        """
        self._check_uproot_version()
        self._check_rootfile_extension(rootfile_path)
        self._load_rootfile(rootfile_path)
        print("{}\n {} event loaded.".format(self._rootfile_path, self._n_event))
        self._prepare_variables()

    def _check_uproot_version(self):
        """Require uproot >= 4 (the ``array(library="np")`` API used below)."""
        version = uproot.__version__
        # parse the full major number so e.g. "5.x" and "10.x" are handled
        major = int(version.split(".")[0])
        if major < 4:
            print("class HitArrayGen require uproot ver 4, current ver {}".format(version))
            raise RuntimeError("class HitArrayGen require uproot ver 4 or later, current ver {}".format(version))

    def _check_rootfile_extension(self, rootfile_path):
        """Warn (without failing) when the path does not end in ``.root``."""
        if not rootfile_path.endswith(".root"):
            print("rootfile_path must be finish '.root'")

    def _prepare_variables(self):
        # default ADC threshold for every channel
        self._thresholds = np.array([1200 for _ in range(64)])

    def _load_rootfile(self, rootfile_path):
        self._rootfile_path = rootfile_path
        with uproot.open(self._rootfile_path) as file:
            self._tree = file["tree"]
            self._VadcHigh = self._tree["VadcHigh"].array(library="np")
        self._n_event = len(self._VadcHigh)

    def set_threshold(self, ch, adc_th):
        """Set the ADC threshold of channel ``ch``."""
        self._thresholds[ch] = adc_th

    def generate_hit_array(self):
        """Compute ``_hit_array`` (n_event, 8, 4, 4) from the current thresholds."""
        # ch order => upside, downside order: (n_event, 8, 4) each
        Vadc_upside = self._VadcHigh[:, self.CHANNELS_UPSIDE]
        Vadc_downside = self._VadcHigh[:, self.CHANNELS_DOWNSIDE]
        # crossed maps: element [..., p, q] holds upside channel p / downside channel q
        self._Vadc_upside_map = np.repeat(Vadc_upside[:, :, :, np.newaxis], 4, axis=3)
        self._Vadc_downside_map = np.repeat(Vadc_downside[:, :, np.newaxis, :], 4, axis=2)
        # same layout for the thresholds: (8, 4, 4)
        self._thresholds_upside_map = np.repeat(self._thresholds[self.CHANNELS_UPSIDE][:, :, np.newaxis], 4, axis=2)
        self._thresholds_downside_map = np.repeat(self._thresholds[self.CHANNELS_DOWNSIDE][:, np.newaxis, :], 4, axis=1)
        # pixel hit = upside hit AND downside hit
        self._hit_array = (self._thresholds_upside_map < self._Vadc_upside_map) * (self._thresholds_downside_map < self._Vadc_downside_map)

    def get_hit_array(self):
        return self._hit_array
from . import calibrationUtils
import ROOT as r
from copy import copy
from tqdm import tqdm
import os

class ExtendedTTree(r.TChain):
    """TChain over one ROOT file with per-channel ``VadcHigh`` histograms.

    On construction, a 4096-bin histogram of the ``VadcHigh`` branch is filled for
    each of the 64 channels. The pedestal of each channel can then be fitted with
    a Gaussian, and the histograms can be saved as PNG images.
    """
    _root_file_name = ""
    _n_events = None

    def __init__(self, tree_name, tree_title, root_file_name, max_events=1000):
        """Chain ``root_file_name`` and build the 64 VadcHigh histograms.

        Args:
            tree_name: name of the tree inside the ROOT file.
            tree_title: title passed to the TChain.
            root_file_name: path of the ROOT file added to the chain.
            max_events: maximum number of events read per histogram (the old
                debug cap of 1000 by default); ``None`` reads every event.
        """
        super().__init__(tree_name, tree_title)
        self._root_file_name = root_file_name
        # Per-instance storage. These used to be class attributes, so every
        # ExtendedTTree shared (and overwrote) the same two 64-element lists.
        self._hists_VadcHigh = [None for _ in range(64)]
        self._funcs_pedestal_fit = [None for _ in range(64)]
        self.Add(self._root_file_name)
        self._n_events = self.GetEntries()

        for i in tqdm(range(64), desc="[tree->hist]"):
            self.construct_hist_VadcHigh(i, max_events)

    def construct_hist_VadcHigh(self, ch, max_events=1000):
        """Fill and store the VadcHigh histogram (4096 bins over 0-4096) of channel ``ch``.

        Args:
            ch: channel index, 0-63.
            max_events: read at most this many events; ``None`` reads all of them.
        """
        hist = r.TH1D(
            "VadcHigh {}ch".format(ch),
            "VadcHigh {}ch;ADC;events".format(ch),
            4096, 0, 4096
        )
        # Never read past the last entry: GetEntry on a missing entry leaves the
        # branch buffer untouched, which re-filled the previous event's value.
        if max_events is None:
            n_read = self._n_events
        else:
            n_read = min(max_events, self._n_events)
        for i in range(n_read):
            self.GetEntry(i)
            hist.Fill(self.VadcHigh[ch])

        self._hists_VadcHigh[ch] = hist

    def fit_pedestal(self, ch, xmin, xmax):
        """Fit a Gaussian to channel ``ch``'s histogram over [xmin, xmax] and keep the TF1."""
        func = r.TF1("func ch{}".format(ch), "gaus", xmin, xmax)
        self._hists_VadcHigh[ch].Fit(func, "R")
        self._funcs_pedestal_fit[ch] = func

    def fit_pedestal_auto_range(self, ch):
        """Fit the pedestal over the span from the first to the last non-empty bin.

        The same span is also set as the histogram's displayed x range.

        Raises:
            ValueError: if the histogram of ``ch`` has no entries in bins
                0..nbins (previously this case looped forever).
        """
        hist = self._hists_VadcHigh[ch]
        last_bin = hist.GetNbinsX()
        # Bin numbers are used directly as x values, as before.
        xmin = 0
        while xmin <= last_bin and hist.GetBinContent(xmin) == 0:
            xmin += 1
        if xmin > last_bin:
            raise ValueError("VadcHigh histogram of ch{} is empty; cannot fit the pedestal".format(ch))
        # Terminates at the latest at xmin, whose bin is known to be non-empty.
        xmax = last_bin
        while hist.GetBinContent(xmax) == 0:
            xmax -= 1
        hist.GetXaxis().SetRangeUser(xmin, xmax)
        self.fit_pedestal(ch, xmin, xmax)

    def get_pedestal_fit_params(self, ch):
        """Return the Gaussian pedestal-fit parameters of ``ch`` as (value, error) pairs.

        Keys are ``"constant"``, ``"mean"`` and ``"sigma"`` (TF1 parameters 0, 1, 2).
        """
        func = self._funcs_pedestal_fit[ch]
        ret = {
            "constant": (func.GetParameter(0), func.GetParError(0)),
            "mean": (func.GetParameter(1), func.GetParError(1)),
            "sigma": (func.GetParameter(2), func.GetParError(2))
        }
        return ret

    def save_hist_VadcHigh(self, ch, save_dir_path="./"):
        """Draw channel ``ch``'s histogram and save it as a PNG image.

        The PNG goes to ``<save_dir_path>/img_<file name without "root">/VadcHigh_<ch>ch.png``;
        the directory is created if needed.
        """
        base_name = os.path.basename(self._root_file_name).replace(".root", "").replace("root", "")
        save_dir = os.path.join(save_dir_path, "img_" + base_name)
        # exist_ok: save_hists_VadcHigh calls this once per channel, and a bare
        # os.mkdir raised FileExistsError from the second channel on.
        os.makedirs(save_dir, exist_ok=True)
        hist = self._hists_VadcHigh[ch]
        canvas = r.TCanvas()
        hist.Draw()
        canvas.SaveAs(os.path.join(save_dir, "VadcHigh_{}ch.png".format(ch)))

    def save_hists_VadcHigh(self, save_dir_path="./"):
        """Save the VadcHigh histograms of all 64 channels (see save_hist_VadcHigh)."""
        for ch in tqdm(range(64), desc="[img saving]"):
            self.save_hist_VadcHigh(ch, save_dir_path)
from typing import List
import numpy as np
import networkx as nx
from .TrackReconstructorBase import TrackReconstructorBase


class PionTrackReconstructor(TrackReconstructorBase):
    """Track reconstructor for pion events with a top-layer single-hit preselection."""

    def __init__(self, rootfile_path: str, threshold_s: List[int]) -> None:
        super().__init__(rootfile_path, threshold_s)
        self._preselect()

    def _preselect(self):
        """Collect the events in which exactly one pixel fired in the top layer.

        The top layer is the last index of the layer axis of ``_hit_array``.
        Sets ``_layer_n_hits`` (hits per event and layer), and the index and
        hit arrays of the preselected events.
        """
        # hits per (event, layer): sum over the two pixel axes
        self._layer_n_hits = np.sum(self._hit_array, axis=(2, 3))
        top_layer_n_hits = self._layer_n_hits[:, -1]
        self._preselected_events_index = np.where(top_layer_n_hits == 1)[0]
        self._preselected_events_array = self._hit_array[self._preselected_events_index]
        n_total = self._hit_array.shape[0]
        # guard against an empty file instead of raising ZeroDivisionError
        remain = 100 * self._preselected_events_array.shape[0] / n_total if n_total else 0.0
        print("preselected, {:.2f}% remain".format(remain))

    def write_preselected_event(self, i):
        """Show the ``i``-th preselected event (via the base class's ``show``)."""
        self.show(self._preselected_events_index[i])

    def print_preselected_event(self, i):
        """Print the hit array of the ``i``-th preselected event."""
        print(self._hit_array[self._preselected_events_index[i]])
import os
from .TrackReconstructorBase import TrackReconstructorBase
import numpy as np
from tqdm import tqdm
import itertools
from scipy import optimize

class FitPointMuonTrackReconstructor(TrackReconstructorBase):
    """Muon track reconstructor: event selection plus a surface fit through hit centres.

    The selection steps are chained. Each one iterates over the event indices
    kept by the previous step, so they must be called in this order:
    ``_pre_cut_threshold_layer`` -> ``_select_top_layer_hit_event`` ->
    ``_select_bottom_layer_hit_event`` -> ``_cut_non_2_layer_continue_event`` ->
    ``_select_2hits`` -> ``_select_top_layer_continue_hits_event`` ->
    ``_select_continue_2_hits``.

    Layer index 7 of the hit array is the top layer and index 0 the bottom one;
    each layer is a 4x4 grid of pixels.
    """

    def __init__(self, rootfile_path, threshold_s) -> None:
        super().__init__(rootfile_path, threshold_s)
        self._hit_array = self._hit_array_gen.get_hit_array()

    def _remain_percent(self, n_selected):
        """Return ``n_selected`` as a percentage of all events (0.0 when there are none)."""
        n_total = self._hit_array.shape[0]
        return 100 * n_selected / n_total if n_total else 0.0

    def _has_neighbor_hit(self, i_event, i_layer, i, j):
        """Return True if layer ``i_layer`` has a hit within one row/column of pixel (i, j).

        The 3x3 window is clipped at the edges of the layer. The old inline slices
        used ``j-1`` for the corner pixel (0, 0), which became the empty slice
        ``-1:2``, so a corner hit never found its neighbours.
        """
        layer = self._hit_array[i_event][i_layer]
        window = layer[max(i - 1, 0):i + 2, max(j - 1, 0):j + 2]
        return bool(np.any(window == 1))

    def _pre_cut_threshold_layer(self):
        """Keep only events in which at least 6 layers have one or more hits."""
        # hits per (event, layer)
        self._layer_n_hit = np.sum(self._hit_array, axis=(2, 3))
        # number of layers with at least one hit, per event
        self._hit_layer_number = np.count_nonzero(self._layer_n_hit, axis=1)
        self._threshold_layer_number = 6
        self._pre_cut_index = np.where(self._hit_layer_number >= self._threshold_layer_number)[0]
        print("preselected, {:.2f}% remain".format(
            self._remain_percent(self._pre_cut_index.shape[0])
        ))

    def _select_top_layer_hit_event(self):
        """From the pre-cut events, keep those with exactly one hit in the top layer (index 7)."""
        self._select_top_layer_hit_event_index = []
        self._threshold_hit = 1
        for i_event in tqdm(self._pre_cut_index):
            if self._layer_n_hit[i_event][7] == self._threshold_hit:
                self._select_top_layer_hit_event_index.append(i_event)

        print("select_top_layer_hit_event is ", len(self._select_top_layer_hit_event_index))
        print("selected select_top_layer_hit_event, {:.2f}% remain".format(
            self._remain_percent(len(self._select_top_layer_hit_event_index))
        ))

    def _select_bottom_layer_hit_event(self):
        """From the top-layer selection, keep events with exactly one hit in the bottom layer (index 0)."""
        self._select_bottom_layer_hit_event_index = []
        self._threshold_hit = 1
        for i_event in tqdm(self._select_top_layer_hit_event_index):
            if self._layer_n_hit[i_event][0] == self._threshold_hit:
                self._select_bottom_layer_hit_event_index.append(i_event)

        print("select_top & bottom layer_hit_event is ", len(self._select_bottom_layer_hit_event_index))
        print("selected select top & bottom layer_hit_event, {:.2f}% remain".format(
            self._remain_percent(len(self._select_bottom_layer_hit_event_index))
        ))

    def _cut_non_2_layer_continue_event(self):
        """Drop events that contain two consecutive layers without any hit."""
        self._cut_non_2_layer_continue_event_index = []
        for i_event in tqdm(self._select_bottom_layer_hit_event_index):
            n_hit = self._layer_n_hit[i_event]
            # Compare every layer with the next one. The old slice
            # [i_layer:i_layer+1] held a single layer, so any one empty layer
            # rejected the event instead of a pair of consecutive empty layers.
            has_empty_pair = any(
                n_hit[i_layer] == 0 and n_hit[i_layer + 1] == 0 for i_layer in range(7)
            )
            if not has_empty_pair:
                self._cut_non_2_layer_continue_event_index.append(i_event)

        print("cut non 2 layer continue event is ", len(self._cut_non_2_layer_continue_event_index))
        print("selected cut non 2 layer continue event, {:.2f}% remain".format(
            self._remain_percent(len(self._cut_non_2_layer_continue_event_index))
        ))

    def _select_2hits(self):
        """Keep events with 2 or more hits in at least one of layer indices 1-5.

        NOTE(review): the original comment said "layers 3-7", but the slice is
        layer indices 1..5 — confirm which range is intended.
        """
        self._select_2hits_index = []
        for i_event in tqdm(self._cut_non_2_layer_continue_event_index):
            if np.any(self._layer_n_hit[i_event][1:6] >= 2):
                self._select_2hits_index.append(i_event)
        print("select 2 hits event is ", len(self._select_2hits_index))
        print("selected 2 hits event, {:.2f}% remain".format(
            self._remain_percent(len(self._select_2hits_index))
        ))

    def _select_top_layer_continue_hits_event(self):
        """Keep events in which a hit in layer 6 has a neighbouring hit in layer 7 (the top).

        These are the two topmost layers ("layers 0 and 1" counted from the top).
        """
        self._select_top_layer_continue_hits_event_index = []
        for i_event in tqdm(self._select_2hits_index):
            for i, j in itertools.product(range(4), range(4)):
                if self._hit_array[i_event][6][i][j] == 1 and self._has_neighbor_hit(i_event, 7, i, j):
                    self._select_top_layer_continue_hits_event_index.append(i_event)
                    break

        print("0,1 layer continue hits event is ", len(self._select_top_layer_continue_hits_event_index))
        print("selected 0,1th layer continue hits event , {:.2f}% remain".format(
            self._remain_percent(len(self._select_top_layer_continue_hits_event_index))
        ))

    def _select_continue_2_hits(self):
        """Keep events with a hit in a layer 2-6 that has a neighbouring hit in the layer above.

        Both of those layers must also have 2 or more hits.

        NOTE(review): the original comment said "layers 3-7"; confirm the index range.
        """
        self._select_continue_2_hits_index = []
        for i_event in tqdm(self._select_top_layer_continue_hits_event_index):
            for i_layer, i, j in itertools.product(range(2, 7), range(4), range(4)):
                if (self._hit_array[i_event][i_layer][i][j] == 1
                        and self._has_neighbor_hit(i_event, i_layer + 1, i, j)
                        and np.all(self._layer_n_hit[i_event][i_layer:i_layer + 2] >= 2)):
                    self._select_continue_2_hits_index.append(i_event)
                    break

        print("all cut is ", len(self._select_continue_2_hits_index))
        print("selected all cut, {:.2f}% remain".format(
            self._remain_percent(len(self._select_continue_2_hits_index))
        ))

    def _get_fit_point(self, i_event):
        """Take the centres of the hit rhombi in the top and bottom layers of ``i_event``.

        Sets ``_point_top`` / ``_point_bottom`` (from ``get_point``), the six-point
        ``_x`` / ``_y`` / ``_z`` arrays used by the fit, and the initial parameters
        ``_param``. Only the first hit of each layer is used; the selection steps
        keep events with exactly one hit in those layers.
        """
        self._i_event = i_event
        self._top_layer = 7
        self._bottom_layer = 0

        # np.where gives (row indices, column indices) of the hit pixels
        self._hit_top_point = np.where(self._hit_array[self._i_event][self._top_layer] == 1)
        self._hit_bottom_point = np.where(self._hit_array[self._i_event][self._bottom_layer] == 1)

        # index arrays; their first elements are passed to get_point below
        self._hit_top_i = self._hit_top_point[0]
        self._hit_top_j = self._hit_top_point[1]
        self._hit_bottom_i = self._hit_bottom_point[0]
        self._hit_bottom_j = self._hit_bottom_point[1]

        self._point_top = self.get_point(self._top_layer, self._hit_top_i[0], self._hit_top_j[0])
        self._point_bottom = self.get_point(self._bottom_layer, self._hit_bottom_i[0], self._hit_bottom_j[0])

        print("top:", self._point_top, "bottom:", self._point_bottom)

        # The 5-parameter fit needs at least 5 points (scipy.optimize.leastsq
        # requires as many residuals as parameters), but only 2 points exist, so
        # each one is repeated 3 times.
        # NOTE(review): repeated points add no information; the fit stays
        # under-determined.
        self._x = np.repeat([self._point_top[0], self._point_bottom[0]], 3)
        self._y = np.repeat([self._point_top[1], self._point_bottom[1]], 3)
        self._z = np.repeat([self._point_top[2], self._point_bottom[2]], 3)
        # initial parameters of the fit
        self._param = [0, 0, 0, 0, 0]

    def fiting_func(self, param, x, y, z):
        """Residual of z against the surface a*x**2 + b*y**2 + c*x + d*y + e (param = [a, b, c, d, e])."""
        residual = z - (param[0]*x**2 + param[1]*y**2 + param[2]*x + param[3]*y + param[4])
        return residual

    def _get_fit_line_equition(self, i_event):
        """Least-squares fit of ``fiting_func`` to the points built by ``_get_fit_point``.

        ``_get_fit_point`` must be called first. The fitted coefficients are stored
        in ``_a`` .. ``_e``, and the surface evaluated at ``_x`` / ``_y`` is stored
        in ``_line_equition``. ``i_event`` is unused; it is kept for compatibility.
        """
        # _x/_y/_z already hold the repeated top/bottom points; the old code
        # rebuilt identical arrays here.
        self._optimised_param = optimize.leastsq(self.fiting_func, self._param, args=(self._x, self._y, self._z))

        print(self._optimised_param)

        # fitted surface coefficients
        self._a, self._b, self._c, self._d, self._e = self._optimised_param[0]

        print("a:", self._a)
        print("b:", self._b)
        print("c:", self._c)
        print("d:", self._d)
        print("e:", self._e)

        self._line_equition = self._a * self._x**2 + self._b * self._y**2 + self._c * self._x + self._d * self._y + self._e

    def write_fig(self, i_event):
        """Save ``self._fig`` for ``i_event`` as PNG (scale 10) and HTML.

        The output directory name encodes the input file name and the selection
        thresholds, so the selection steps must have run first. ``self._fig`` is
        not created in this class (presumably by the base class — confirm).
        """
        filename_short = self._rootfile_path.split('/')[-1]
        self._save_directory_path = "img_{}_{}_layer_hits_{}_hits_top_bottom_cut".format(filename_short, self._threshold_layer_number, self._threshold_hit)
        os.makedirs(self._save_directory_path, exist_ok=True)
        png_path = self._save_directory_path + "/event{}.png".format(i_event)
        self._fig.write_image(png_path, scale=10)
        print(png_path)
        html_path = self._save_directory_path + "/event{}.html".format(i_event)
        self._fig.write_html(html_path)
        print(html_path)
from typing import List, Dict, Tuple
import ROOT as r
from tqdm import tqdm
from . import calibrationUtils as util
import os


class CalibrationData:
    """ADC histograms and photon-peak fits of the 64 channels at one MPPC HV."""

    def __init__(self, image_dir_path: str, MPPC_high_voltage: str) -> None:
        """
        Args:
            image_dir_path: directory under which ``<ch>/hist.png`` images are saved.
                The per-channel subdirectories must already exist.
            MPPC_high_voltage: HV label of this data set.
        """
        self._image_dir_path: str = image_dir_path
        self._HV: str = MPPC_high_voltage
        self._hists_VadcHigh: List[r.TH1D] = [None for _ in range(64)]
        self._fitted_adc_means: List[List[float]] = [None for _ in range(64)]
        self._fitted_adc_mean_errors: List[List[float]] = [None for _ in range(64)]

    def set_hist(self, detector_ch: int, cal_root_file_path: str, cal_ch: int) -> None:
        """Store channel ``cal_ch``'s histogram from ``cal_root_file_path`` as ``detector_ch``.

        The histogram comes from ``util.getHistMPPC``. Its title is set to the file
        name and the detector channel.
        """
        hist = util.getHistMPPC(cal_root_file_path, cal_ch)
        hist.SetTitle("{0} [{1}ch];ADC;Events".format(cal_root_file_path.split('/')[-1], detector_ch))
        self._hists_VadcHigh[detector_ch] = hist

    def fit_multi_gaus(
        self,
        ch,
        peak_search_range=(0, 1500),
        fitting_range=(0, 1500),
        peak_search_sigma=10
    ) -> None:
        """Fit the peaks of channel ``ch`` with ``util.getFittedParams`` and save the histogram.

        The peak means and their errors returned by the helper are stored per
        channel. The first..last non-empty bin span is passed along as the
        histogram showing range.

        Raises:
            ValueError: if the histogram has no entries in bins 0..4096
                (previously this case looped forever).
        """
        hist = self._hists_VadcHigh[ch]
        # determine hist showing range (bin numbers used as x values, as before)
        # NOTE(review): 4096 assumes util.getHistMPPC builds 4096-bin histograms.
        last_bin = 4096
        xmin = 0
        while xmin <= last_bin and hist.GetBinContent(xmin) == 0:
            xmin += 1
        if xmin > last_bin:
            raise ValueError("VadcHigh histogram of ch{} is empty; cannot fit".format(ch))
        # terminates at the latest at xmin, whose bin is known to be non-empty
        xmax = last_bin
        while hist.GetBinContent(xmax) == 0:
            xmax -= 1
        # fit
        ret_adc_means, ret_adc_mean_errors = util.getFittedParams(
            hist,
            peak_search_range,
            fitting_range,
            (xmin, xmax),
            peak_search_sigma
        )
        # set to member variable & save hist as image
        self._fitted_adc_means[ch] = ret_adc_means
        self._fitted_adc_mean_errors[ch] = ret_adc_mean_errors
        self.save_hist_as_png(ch)

    def save_hist_as_png(self, ch):
        """Draw channel ``ch``'s histogram and save it to ``<image_dir_path>/<ch>/hist.png``."""
        canvas = r.TCanvas()
        self._hists_VadcHigh[ch].Draw()
        canvas.SaveAs("{0}/{1}/hist.png".format(self._image_dir_path, ch))


class CalibrationDatas:
    """Calibration across several MPPC HVs, built on one CalibrationData per HV.

    Workflow:
    1. Register data sets with ``set_calb_data``.
    2. Fit ADC-vs-photon-number lines per (HV, channel).
    3. Fit the one-photon ADC width against HV per channel.
    4. Write InputDAC settings aimed at a common one-photon width
       (``make_yml_InputDAC``).
    """

    def __init__(self) -> None:
        self._calbDatas: Dict[str, CalibrationData] = {}
        self._HVs: List[str] = []
        self._calb_line_TGraphs: Dict[str, List[r.TGraphErrors]] = {}
        self._calb_line_TF1s: Dict[str, List[r.TF1]] = {}
        self._calb_line_TCanvases: Dict[str, List[r.TCanvas]] = {}
        # Axis frames drawn on the calibration-line canvases. Keeping a reference
        # stops PyROOT from deleting them (and dropping them from the stored
        # canvases) as soon as fit_adc_nphoton_line returns.
        self._calb_line_axes: Dict[str, List[r.TH2D]] = {}
        self._HV_one_photon_TGraphs: List[r.TGraphErrors] = [None for _ in range(64)]
        self._HV_one_photon_TF1s: List[r.TF1] = [None for _ in range(64)]
        self._pedestal_data_path: str = None
        self._pedestal_adc_means: List[float] = [None for _ in range(64)]
        self._pedestal_adc_mean_errors: List[float] = [None for _ in range(64)]
        self._initial_photon_number_s: Dict[str, List[int]] = {}

    def set_calb_data(self, img_dir_path: str, HV: str) -> None:
        """Register a new CalibrationData for ``HV`` and create its image directories."""
        self._calbDatas[HV] = CalibrationData(img_dir_path, HV)
        self._HVs.append(HV)
        self._calb_line_TCanvases[HV] = [None for _ in range(64)]
        self._calb_line_TGraphs[HV] = [None for _ in range(64)]
        self._calb_line_TF1s[HV] = [None for _ in range(64)]
        self._calb_line_axes[HV] = [None for _ in range(64)]
        self._initial_photon_number_s[HV] = [None for _ in range(64)]
        self.make_dirs()

    def get_calb_data(self, HV: str) -> CalibrationData:
        """Return the CalibrationData registered for ``HV``."""
        return self._calbDatas[HV]

    def fit_adc_nphoton_line(self, HV, ch, initial_photon_num):
        """Fit ADC peak mean vs. photon number with a straight line and save the plot.

        The i-th fitted peak of (HV, ch) is assigned photon number
        ``initial_photon_num + i``. The graph, fitted TF1 and canvas are stored
        per (HV, ch), and the canvas is saved as ``<img dir>/<ch>/graph_photon_adc.png``.
        """
        # init graph
        adc_means = self._calbDatas[HV]._fitted_adc_means[ch]
        adc_mean_errors = self._calbDatas[HV]._fitted_adc_mean_errors[ch]
        n_points = len(adc_means)
        photon_nums = [initial_photon_num + i for i in range(n_points)]
        photon_num_errors = [0 for _ in range(n_points)]
        graph = util.TPGraphErrors(
            n_points,
            photon_nums,
            adc_means,
            photon_num_errors,
            adc_mean_errors
        )
        graph.SetTitle("{}ch;Photon Number;ADC Value".format(ch))
        graph.SetMarkerStyle(8)
        graph.SetMarkerSize(1)

        # init liner function for fitting and fit. Option "R" only uses points
        # inside the function range, so widen the historical 0-20 range when the
        # photon numbers fall outside it instead of silently dropping points.
        fit_xmin = min(0, photon_nums[0])
        fit_xmax = max(20, photon_nums[-1] + 1)
        f_fit = r.TF1("f_liner", "[0]*x + [1]", fit_xmin, fit_xmax)
        graph.Fit(f_fit, "R")

        # init axis frame for the TGraph: one bin per axis (the old 0 bins only
        # triggered a ROOT warning), a unique name per (HV, ch), and an
        # increasing y range even when the fitted slope is negative
        photon_num_range = (0, photon_nums[-1] + 1)
        adc_bounds = tuple(map(f_fit.Eval, photon_num_range))
        adc_range = (min(adc_bounds), max(adc_bounds))
        axis = r.TH2D(
            "axis_{}_{}".format(HV, ch), "{}ch;Photon Number;ADC Value".format(ch),
            1, *photon_num_range,
            1, *adc_range
        )
        axis.SetStats(0)

        # draw to canvas
        canvas = r.TCanvas()
        axis.Draw("AXIS")
        graph.Draw("P SAME")

        # set to member variables & save canvas as png
        self._calb_line_TCanvases[HV][ch] = canvas
        self._calb_line_TGraphs[HV][ch] = graph
        self._calb_line_TF1s[HV][ch] = f_fit
        self._calb_line_axes[HV][ch] = axis
        self.save_calb_line_TCanvas(HV, ch)

    def fit_all_adc_nphoton_line(self):
        """Run fit_adc_nphoton_line for every HV and channel using the stored initial photon numbers."""
        for HV in self._HVs:
            for ch in range(64):
                initial_photon_num = self._initial_photon_number_s[HV][ch]
                self.fit_adc_nphoton_line(HV, ch, initial_photon_num)

    def save_calb_line_TCanvas(self, HV, ch):
        """Save the calibration-line canvas of (HV, ch) to ``<img dir>/<ch>/graph_photon_adc.png``."""
        save_str = "{0}/{1}/graph_photon_adc.png".format(
            self._calbDatas[HV]._image_dir_path,
            ch
        )
        self._calb_line_TCanvases[HV][ch].SaveAs(save_str)

    def fit_HV_one_photon(self, ch):
        """Fit the one-photon ADC width (slope of each calibration line) against HV for ``ch``."""
        # fetch graph attr
        n_points = len(self._HVs)
        HVs = []
        HV_errors = []
        one_photon_adc_widthes = []
        one_photon_adc_width_errors = []
        for HV in self._HVs:
            HVs.append(float(HV))
            one_photon_adc_widthes.append(
                self._calb_line_TF1s[HV][ch].GetParameter(0)
            )
            HV_errors.append(0)
            one_photon_adc_width_errors.append(
                self._calb_line_TF1s[HV][ch].GetParError(0)
            )

        # init graph
        graph = util.TPGraphErrors(
            n_points,
            HVs,
            one_photon_adc_widthes,
            HV_errors,
            one_photon_adc_width_errors
        )
        graph.SetTitle("{}ch;MPPC HV [V];ADC/One Photon".format(ch))
        graph.SetMarkerStyle(8)
        graph.SetMarkerSize(1)

        # init liner function for fitting and fit; as above, keep every HV
        # inside the "R" fit range (historically 0-60 V)
        fit_xmax = max([60] + [HV + 1 for HV in HVs])
        f_fit = r.TF1("f_liner", "[0]*x + [1]", 0, fit_xmax)
        graph.Fit(f_fit, "R")

        # set to class member variable & save graph as png
        self._HV_one_photon_TGraphs[ch] = graph
        self._HV_one_photon_TF1s[ch] = f_fit
        self.save_HV_one_photon_TGraph(ch)

    def save_HV_one_photon_TGraph(self, ch):
        """Save ``ch``'s HV-vs-one-photon graph into the image directory of every HV."""
        save_str = "{0}/{1}/HV_one_photon_TGraph.png"
        canvas = r.TCanvas()
        self._HV_one_photon_TGraphs[ch].Draw("AP")
        for HV in self._HVs:
            canvas.SaveAs(save_str.format(self._calbDatas[HV]._image_dir_path, ch))

    def make_dirs(self):
        """Create ``<img dir>`` and ``<img dir>/<ch>`` (ch = 0-63) for every registered HV."""
        for HV in self._HVs:
            os.makedirs(self._calbDatas[HV]._image_dir_path, exist_ok=True)
            for i in range(64):
                os.makedirs("{0}/{1}".format(self._calbDatas[HV]._image_dir_path, i), exist_ok=True)

    def set_pedestal_data(self, pedestal_data_path):
        """Fit a Gaussian over 0-4096 to each channel's pedestal histogram and keep its mean and error."""
        self._pedestal_data_path = pedestal_data_path
        hists = [util.getHistMPPC(self._pedestal_data_path, ch) for ch in range(64)]
        funcs = [r.TF1("", "gaus", 0, 4096) for _ in range(64)]
        for ch in range(64):
            hists[ch].Fit(funcs[ch], "R")
            self._pedestal_adc_means[ch] = funcs[ch].GetParameter(1)
            self._pedestal_adc_mean_errors[ch] = funcs[ch].GetParError(1)

    def determine_initial_photon_number(self, HV, ch):
        """Estimate the photon number of the first fitted peak of (HV, ch).

        Computed as round((first peak - pedestal) / (second peak - first peak)).
        Requires set_pedestal_data and at least two fitted peaks.
        """
        fitted_means = self._calbDatas[HV]._fitted_adc_means[ch]
        pedestal_mean = self._pedestal_adc_means[ch]
        diff_ped_to_ini = fitted_means[0] - pedestal_mean
        aprox_width = fitted_means[1] - fitted_means[0]
        initial_photon_number = round(diff_ped_to_ini / aprox_width)
        self._initial_photon_number_s[HV][ch] = initial_photon_number

    def determine_all_initial_photon_number(self):
        """Run determine_initial_photon_number for every HV and channel."""
        for HV in self._HVs:
            for ch in range(64):
                self.determine_initial_photon_number(HV, ch)

    def print_fitted_pedestal(self):
        """Print, per HV, each channel's calibration line evaluated at 0 photons (its intercept)."""
        for HV in self._HVs:
            print("========== {}V ==========".format(HV))
            for ch in range(64):
                print(self._calb_line_TF1s[HV][ch].Eval(0))

    def get_HV_from_one_photon(self, ch, one_photon_width) -> float:
        """Invert ``ch``'s width = a*HV + b fit: return the HV giving ``one_photon_width``."""
        a = self._HV_one_photon_TF1s[ch].GetParameter(0)
        b = self._HV_one_photon_TF1s[ch].GetParameter(1)
        return (one_photon_width - b) / a

    def make_yml_InputDAC(self, target_width) -> None:
        """Write ``InputDAC.yml`` with per-channel InputDAC values for a common one-photon width.

        Channel 0's target HV is used as the reference (written as ``setHV``).
        Each channel's DAC value is 256 + 128 plus its HV difference to the
        reference in steps of 4.5/256 V, truncated toward zero.

        NOTE(review): values are not clamped to the 256-511 range mentioned in
        set_InputDAC_mesurement_data — confirm the hardware accepts them.
        """
        HV_target_s = [self.get_HV_from_one_photon(ch, target_width) for ch in range(64)]
        HV_ref = HV_target_s[0]
        HV_diff_s = [HV_ref - HV for HV in HV_target_s]
        DAC_bit_s = [256 + 128 + int(HV_diff / (4.5/256)) for HV_diff in HV_diff_s]
        out_str = "# setHV {}\n".format(HV_ref)
        out_str += "---\n"
        out_str += "EASIROC1:\n"
        out_str += "  Input 8-bit DAC:\n"
        for ch in range(0, 32):
            out_str += "  - {}\n".format(DAC_bit_s[ch])
        out_str += "EASIROC2:\n"
        out_str += "  Input 8-bit DAC:\n"
        for ch in range(32, 64):
            out_str += "  - {}\n".format(DAC_bit_s[ch])

        with open("InputDAC.yml", 'w') as f:
            f.write(out_str)

    def set_InputDAC_mesurement_data(self, d: Dict[float, List[List[List[float]]]]) -> None:
        """Store InputDAC measurements and create the ``InputDAC_fit`` output directory.

        Layout of ``d``:
            d[setHV][ch][0] = [DAC_value (256-511)]
            d[setHV][ch][1] = [DAC_voltage]
        """
        self._InputDAC_mesurement_data = d
        os.makedirs("InputDAC_fit", exist_ok=True)

    def set_setHV_to_realHV(self, d: Dict[float, float]) -> None:
        """Store the mapping from set HV to measured (real) HV."""
        self._setHV_to_realHV = d

    def fit_InputDAC_vaule_voltage_line(
        self,
        setHV: float,
        ch: int,
        fit_range: Tuple[float, float] = (258, 350)
    ) -> None:
        """Fit (real HV - DAC voltage) against DAC value for (setHV, ch) and save the plot.

        The straight-line fit uses ``fit_range``. The plot is saved to
        ``InputDAC_fit/setHV<setHV>_ch<ch>.png``; the graph and fit are not kept.
        """
        n_points = len(self._InputDAC_mesurement_data[setHV][ch][0])
        MPPC_HVs = [self._setHV_to_realHV[setHV] for _ in range(n_points)]
        truth_HVs = [
            MPPC_HV - DAC_V
            for MPPC_HV, DAC_V in zip(MPPC_HVs, self._InputDAC_mesurement_data[setHV][ch][1])
        ]
        g = util.TPGraphErrors(
            n_points,
            self._InputDAC_mesurement_data[setHV][ch][0],
            truth_HVs,
            [0 for _ in range(n_points)],
            [0 for _ in range(n_points)]
        )
        f = r.TF1(
            "InputDAC_vaule_voltage_{}_{}".format(setHV, ch),
            "[0]*x+[1]",
            fit_range[0],
            fit_range[1]
        )
        c = r.TCanvas()
        g.Fit(f, "R")
        g.Draw("AP")
        c.SaveAs("InputDAC_fit/setHV{0}_ch{1}.png".format(setHV, ch))
from . import TrackReconstructorBase
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import os
from PIL import Image, ImageChops


def crop_bg(img):
    """Crop ``img`` to the bounding box of its non-white pixels.

    When every pixel is white, ``Image.crop(None)`` returns a copy of the whole image.
    """
    # Compare against a pure-white canvas instead of thresholding to mode "1".
    # convert(mode="1", dither=None) does NOT disable dithering (None selects
    # Floyd-Steinberg), so the old bounding box was taken on a dithered image.
    rgb_img = img.convert("RGB")
    white = Image.new("RGB", rgb_img.size, (255, 255, 255))
    crop_range = ImageChops.difference(rgb_img, white).getbbox()
    return img.crop(crop_range)


def get_concat_v(imgs, bg_color=0):
    """Stack images vertically, top to bottom, on a new RGB canvas.

    The canvas is as wide as the widest image (the old code used the first
    image's width and clipped wider ones). Narrower images are left-aligned and
    the uncovered area is filled with ``bg_color`` (black by default, as before).

    Args:
        imgs: iterable of PIL images.
        bg_color: fill colour passed to ``Image.new``.

    Raises:
        ValueError: if ``imgs`` is empty.
    """
    imgs = list(imgs)
    if not imgs:
        raise ValueError("get_concat_v needs at least one image")
    total_height = sum(img.height for img in imgs)
    max_width = max(img.width for img in imgs)
    dst = Image.new('RGB', (max_width, total_height), bg_color)
    current_height = 0
    for img in imgs:
        dst.paste(img, (0, current_height))
        current_height += img.height
    return dst


class LayerVisualizer:
    """Draw the 8 layers of a reconstructor's pixels as stacked top-down views.

    Each layer is a separate 3-D subplot, viewed orthographically along z.
    ``write_fig`` renders the figure, cuts it into 8 horizontal strips, crops
    each strip with ``crop_bg`` and stacks the strips into one PNG.
    """

    def __init__(self, tr: "TrackReconstructorBase.TrackReconstructorBase"):
        """Build the figure for ``tr``.

        Uses its ``SHOWING_RANGE``, ``pixels``, ``pix_color``,
        ``_filename_short`` and ``_i_event``.
        """
        self._tr = tr
        self._init_fig()
        self._draw_layers()

    def _init_fig(self):
        """Create 8 stacked mesh3d subplots with hidden axes and an orthographic camera on +z."""
        self._fig = make_subplots(
            rows=8, cols=1,
            specs=[
                [{'type': 'mesh3d'}] for _ in range(8)
            ]
        )
        self._fig.update_layout(height=3600*8, width=3000)
        self._fig.update_scenes(
            xaxis=dict(nticks=1, range=self._tr.SHOWING_RANGE,),
            yaxis=dict(nticks=1, range=self._tr.SHOWING_RANGE,),
            zaxis=dict(nticks=1, range=self._tr.SHOWING_RANGE,),
            xaxis_visible=False,
            yaxis_visible=False,
            zaxis_visible=False,
            camera_projection_type="orthographic",
            camera=dict(eye=dict(x=0, y=0, z=1))
        )

    def _draw_layers(self):
        """Draw layers 7 (top) down to 0 (bottom) into subplot rows 1 to 8."""
        for row, i_layer in enumerate(range(7, -1, -1), start=1):
            self._draw_a_layer(i_layer, row, 1)

    def _draw_a_layer(self, i_layer, row, col):
        """Add one Mesh3d trace per pixel of ``i_layer`` (pixels 16*i_layer to 16*i_layer + 15)."""
        for i_pix in range(i_layer*16, i_layer*16 + 16):
            x, y, z = self._tr.pixels[i_pix].get_vertices().T
            i, j, k = self._tr.pixels[i_pix].get_faces().T
            self._fig.add_trace(
                go.Mesh3d(
                    x=x, y=y, z=z, i=i, j=j, k=k,
                    color=self._tr.pix_color[i_pix],
                ),
                row=row, col=col
            )

    def write_fig(self):
        """Render the figure and save the cropped, stacked layers.

        The output file is ``img_<file>/layer_<event>.png``. The full-size
        render goes to a private temporary directory (previously ``tmp.png`` in
        the working directory, which was left behind and shared between runs).
        """
        import tempfile

        os.makedirs("img_{}".format(self._tr._filename_short), exist_ok=True)
        self._save_file_name = "img_{}/layer_{}.png".format(self._tr._filename_short, self._tr._i_event)
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_path = os.path.join(tmp_dir, "tmp.png")
            self._fig.write_image(tmp_path)
            self._crop_image(tmp_path)

    def _crop_image(self, image_path="tmp.png"):
        """Split ``image_path`` into 8 strips, crop each with crop_bg, stack them and save."""
        # `with` closes the source file (needed before the temp dir is removed)
        with Image.open(image_path) as input_image:
            width, height = input_image.width, input_image.height
            cropped_images = []
            for i in range(0, 8):
                strip = input_image.crop((
                    0,
                    int(height * (i/8)),
                    width,
                    int(height * ((i+1)/8))
                ))
                cropped_images.append(crop_bg(strip))
        cropped_concatted_image = get_concat_v(cropped_images)
        cropped_concatted_image.save(self._save_file_name)
        print("{} saved".format(self._save_file_name))
from setuptools import setup, find_packages

# Minimal setuptools configuration: every package found under this directory
# is shipped as part of "pyroot_easiroc".
PACKAGES = find_packages()

setup(
    name="pyroot_easiroc",
    version="0.1",
    packages=PACKAGES,
)