# This is a template file used by _Stencil_CP._gen_code() to JIT-compile critical sections.
# Placeholders, written as a dollar sign followed by a name in curly braces, are replaced at runtime.

import numba.cuda

# Keyword options unpacked into numba.cuda.jit(...) for the device-side
# helper functions defined in this template.
f_dev_flags = {
    "device": True,
    "inline": True,
    "fastmath": True,
    "opt": True,
    "cache": True,
}


@numba.cuda.jit(**f_dev_flags)
def unravel_offset(
    offset,  # int: flat element offset (C-order, see below)
    shape,  # (D+1,): shape of the array being indexed
):
    # Device helper: turns a flat offset into a multi-dimensional index.
    #
    # The actual computation is not written in this template. The
    # unravel_spec placeholder on the line below is replaced with generated
    # code, and that code must bind the name `idx`, which is returned as-is.
    #
    # Compile-time-equivalent of np.unravel_index(offset, shape, order='C').
    ${unravel_spec}
    return idx


@numba.cuda.jit(**f_dev_flags)
def full_overlap(
    idx,  # (D+1,): index to test; idx[0] is compared against dimensions[0]
    dimensions,  # (D+1,): extent of each axis
):
    # Device helper: returns True only when every condition below holds,
    # False as soon as one of them fails.
    #
    # Computes intersection of:
    # *      idx[0]  < dimensions[0]
    # * 0 <= idx[1:] - kernel_center
    # *      idx[1:] - kernel_center <= dimensions[1:] - kernel_width
    if not (idx[0] < dimensions[0]):
        # went beyond stack-dim limit
        return False

    # Both placeholders are filled in by the template. They are iterated in
    # lockstep with idx[1:] and dimensions[1:], so each presumably holds one
    # entry per non-stack axis -- confirm against the code generator.
    kernel_width = ${kernel_width}
    kernel_center = ${kernel_center}
    for i, w, c, n in zip(
        idx[1:],
        kernel_width,
        kernel_center,
        dimensions[1:],
    ):
        # Per axis: (i - c) must be non-negative and (i - c) + w must not
        # exceed n. Read as: a window of width w whose center offset c sits
        # on i has to lie entirely inside [0, n).
        if not (0 <= i - c <= n - w):
            # kernel partially overlaps ROI around `idx`
            return False

    # kernel fully overlaps ROI around `idx`
    return True


# Keyword options unpacked into numba.cuda.jit(...) for the kernel entry
# point defined in this template.
f_flags = {
    "device": False,
    "inline": False,
    "fastmath": True,
    "opt": True,
    "cache": True,
    # "max_registers": None,  # see .apply() notes
}


@numba.cuda.jit(
    func_or_sig="${signature}",  # signature string, filled in by the template
    **f_flags,
)
def f_jit(arr, out):
    # Kernel entry point: each thread writes at most one element of `out`.
    # Elements where full_overlap() holds receive the generated stencil
    # expression; every other element is set to 0.
    #
    # arr: (N_stack, M1,...,MD)
    # out: (N_stack, M1,...,MD)
    offset = numba.cuda.grid(1)  # this thread's absolute position in the 1D grid
    if offset < arr.size:  # threads beyond the last element do nothing
        idx = unravel_offset(offset, arr.shape)
        if full_overlap(idx, arr.shape):
            # The stencil_spec placeholder is replaced with a generated
            # expression; it is evaluated in a scope where `arr` and `idx`
            # are visible.
            out[idx] = ${stencil_spec}
        else:
            # Stencil does not fully fit at `idx`: output is zeroed.
            out[idx] = 0
