Metadata-Version: 2.4
Name: mm-kermac
Version: 1.0.0
Summary: Dynamically compiled hyper semirings for Pytorch using PTX Inject and Stack PTX
Project-URL: Homepage, https://github.com/MetaMachines/mm-kermac-py
Project-URL: Repository, https://github.com/MetaMachines/mm-kermac-py.git
Author-email: MetaMachines LLC <contact@metamachines.co>
License: MIT
License-File: LICENSE
Classifier: Development Status :: 5 - Production/Stable
Classifier: Intended Audience :: Developers
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: MIT License
Classifier: Operating System :: OS Independent
Classifier: Programming Language :: C
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3 :: Only
Classifier: Programming Language :: Python :: 3.8
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Topic :: Scientific/Engineering
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Classifier: Topic :: Scientific/Engineering :: Mathematics
Classifier: Topic :: Software Development :: Compilers
Classifier: Topic :: Software Development :: Libraries
Classifier: Topic :: Software Development :: Libraries :: Python Modules
Classifier: Topic :: System :: Hardware
Requires-Python: >=3.8
Requires-Dist: cuda-toolkit[cccl]
Requires-Dist: mm-ptx>=1.0.0
Requires-Dist: numpy>=1.21
Requires-Dist: platformdirs>=3.0
Requires-Dist: torch>=2.0
Provides-Extra: cu12
Requires-Dist: cuda-core[cu12]>=0.5.0; extra == 'cu12'
Provides-Extra: cu13
Requires-Dist: cuda-core[cu13]>=0.5.0; extra == 'cu13'
Description-Content-Type: text/markdown

# mm-kermac
> Dynamically compiled hyper semirings for PyTorch using PTX Inject and Stack PTX

This repo provides Semiring and Semiring gradient tensor operations for PyTorch. It also provides a DSL for writing your own Semiring and Semiring gradient routines, which may take hyperparameters that are passed into the kernel as tensors. A hyperparameter can be a single-value tensor applied to one problem, a single-value tensor broadcast across a batch of tensors, or a vector of per-batch values applied to a batch of tensors.

## Quickstart
```python
import torch
import mm_kermac.hyper_semiring as kermac

device = torch.device("cuda")
M, N, K = 1024, 2048, 256
x = torch.randn(M, K, device=device)
z = torch.randn(N, K, device=device)
out = torch.empty((M, N), device=device)

# GEMM
gemm = kermac.Gemm()
# First call compiles and loads the module (cached per device).
gemm(x=x, z=z, out=out)

# L2 cdist
norm_l2 = kermac.NormL2()
norm_l2(x=x, z=z, out=out)

# Fractional p via hyperparameters
p = 1.3
norm_lp = kermac.NormLp(epsilon=0.0)
p_inner = torch.tensor(p, device=device)
p_outer = torch.tensor(1.0 / p, device=device)
norm_lp(x=x, z=z, p_inner=p_inner, p_outer=p_outer, out=out)
```
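
As a reading aid for the fractional-p call above, here is a minimal pure-Python sketch of the quantity `NormLp` is presumed to compute for one pair of rows: `(sum_k |x_k - z_k| ** p_inner) ** p_outer`. The formula is an assumption inferred from the `p_inner = p` and `p_outer = 1 / p` arguments, `lp_distance` is a hypothetical helper that is not part of `mm-kermac`, and `epsilon` handling is left out.

```python
def lp_distance(x_row, z_row, p_inner, p_outer):
    """Reference Lp distance between two equal-length rows (assumed formula)."""
    # Inner step: accumulate |x_k - z_k| ** p_inner over the shared K dimension.
    acc = 0.0
    for x_k, z_k in zip(x_row, z_row):
        acc += abs(x_k - z_k) ** p_inner
    # Outer step: apply the 1/p power once, after the reduction.
    return acc ** p_outer


p = 1.3
d = lp_distance([0.0, 0.0], [3.0, 4.0], p_inner=p, p_outer=1.0 / p)
print(f"p={p}: {d:.4f}")
```

Setting `p_inner=2.0, p_outer=0.5` recovers the Euclidean distance, and `p_inner=1.0, p_outer=1.0` recovers the L1 distance, which is a quick way to sanity-check a fractional-p result.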

## Installation
`mm-kermac` only supports NVIDIA GPUs with compute capability `sm_80` or greater:
* Server cards from the A100 generation onward, e.g. A10, H100, B100, BH200
* Consumer cards from the 3000 series onward, e.g. 3070, 3090, 4090, 5090

To install, pick the extra that matches your CUDA toolkit version (the quotes keep shells such as zsh from expanding the brackets):
```bash
pip install "mm-kermac[cu12]"
```
```bash
pip install "mm-kermac[cu13]"
```

## Zoo kernels
The zoo provides ready-to-use kernels built on `HyperSemiringKernel` and `HyperSemiringGradientKernel`.

```python
import torch
import mm_kermac.hyper_semiring as hs

device = torch.device("cuda")
x = torch.randn(1024, 256, device=device)
z = torch.randn(2048, 256, device=device)
out = torch.empty((x.size(0), z.size(0)), device=device)

gemm = hs.Gemm()
norm_l1 = hs.NormL1()
norm_l2 = hs.NormL2()
norm_lp = hs.NormLp(epsilon=0.0)

gemm(x=x, z=z, out=out)
norm_l2(x=x, z=z, out=out, try_to_align=False)
```

For gradient kernels, see `examples/hyper_semiring_gradient.py` for the expected shapes and arguments.

## Benchmarks
The benchmark compares `NormL1`, `NormL2`, and `NormLp` against `torch.cdist` for p=1.0, p=2.0, and a fractional p.

```bash
python examples/bench_hyper_semiring_cdist.py --M 2048 --N 2048 --K 256 --iters 50 --warmup 10 --p-frac 1.3
```

Sample output:
```text
Device: NVIDIA GeForce RTX 5090
M=2048 N=2048 K=256 iters=50 warmup=10
Fractional p=1.3 epsilon=0.0 try_align=False
   case  |          kermac ms |          torch ms | speedup
   p=1.0 | kermac    0.078 ms | torch    4.779 ms |  61.03x
   p=2.0 | kermac    0.080 ms | torch    0.093 ms |   1.15x
   p=1.3 | kermac    0.375 ms | torch    5.312 ms |  14.16x
```
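
For readers who want to time their own kernels in the same warmup-then-iterate style, the following is a minimal sketch of such a harness. It is not the code in `examples/bench_hyper_semiring_cdist.py`; `bench_ms` and `speedup` are hypothetical helpers. Because CUDA kernel launches are asynchronous, GPU work should be timed with a synchronization callback such as `torch.cuda.synchronize` passed as `sync`.

```python
import time


def bench_ms(fn, iters=50, warmup=10, sync=None):
    """Return the mean wall-clock time of fn() in milliseconds."""
    for _ in range(warmup):
        fn()  # warmup runs absorb one-time costs such as compilation
    if sync is not None:
        sync()  # drain queued asynchronous work before starting the clock
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    if sync is not None:
        sync()  # wait for the timed work to finish before stopping the clock
    return (time.perf_counter() - start) * 1e3 / iters


def speedup(baseline_ms, candidate_ms):
    """Ratio reported in the speedup column: baseline time over candidate time."""
    return baseline_ms / candidate_ms


# CPU-only demonstration: time a trivial callable and count its invocations.
calls = []
elapsed = bench_ms(lambda: calls.append(1), iters=5, warmup=2)
print(f"{elapsed:.6f} ms per call over {len(calls)} total calls")
```

The speedup column in the sample output appears to be computed from unrounded timings, so recomputing it from the rounded milliseconds shown can differ slightly in the last digits.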

## How it works
- `HyperSemiringKernel` renders a CUTLASS/CuTe template, injects PTX stubs generated by Stack PTX, compiles to a cubin, and caches per device and signature.
- The kernel is split into `mma_lambda` (per multiply-accumulate step) and `epilogue_lambda` (post-reduction).
- `hyper_dict` maps user names to hyperparameter tensors. Insertion order maps to `hyper0`, `hyper1`, etc. in the generated PTX, and the lambdas receive a `reg_dict` keyed by those user names.
- The number of hyperparameters is user defined; the template is generated and cached separately for each count.
- Hyperparameters may be scalar tensors or length-L tensors; L is inferred from input batches and hyperparameter tensors, so you can batch multiple p values in one call.
- Zoo kernels are thin wrappers that predefine the lambdas and build the right `hyper_dict` for you.
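
The naming and batching rules in the list above can be illustrated with a small pure-Python sketch. Both helpers are hypothetical and only model the described behavior: `hyper_slots` mirrors the insertion-order mapping to `hyper0`, `hyper1`, and so on, while `infer_batch_length` encodes one plausible inference rule (length-1 entries broadcast, all longer entries must agree), which is an assumption rather than the library's documented logic.

```python
def hyper_slots(hyper_dict):
    """Map user names to positional slot names by insertion order."""
    return {name: f"hyper{i}" for i, name in enumerate(hyper_dict)}


def infer_batch_length(lengths):
    """Infer L from per-hyperparameter lengths; length 1 means scalar, broadcast."""
    batched = {n for n in lengths if n != 1}
    if len(batched) > 1:
        raise ValueError(f"conflicting batch lengths: {sorted(batched)}")
    return batched.pop() if batched else 1


# Plain lists stand in for hyperparameter tensors here.
hyper = {"p_inner": [1.0, 1.3, 2.0], "p_outer": [1.0, 1 / 1.3, 0.5]}
print(hyper_slots(hyper))
print(infer_batch_length([len(v) for v in hyper.values()]))
```

Since Python dictionaries preserve insertion order, reordering the keys of `hyper_dict` changes which slot each name lands in, but the lambdas are unaffected because they look values up by user name.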

Custom kernel sketch:
```python
import torch
from mm_kermac import PtxInstruction
from mm_kermac.hyper_semiring import HyperSemiringKernel

device = torch.device("cuda")
x = torch.randn(1024, 256, device=device)
z = torch.randn(2048, 256, device=device)
out = torch.empty((x.size(0), z.size(0)), device=device)

kernel = HyperSemiringKernel(
    mma_lambda=lambda a, b, c, reg: [
        a,  # push a
        b,  # push b
        PtxInstruction.sub_ftz_f32,  # diff = b - a
        reg["beta"],  # push beta (0.5 at run time, read from the hyper["beta"] tensor)
        PtxInstruction.mul_ftz_f32,  # diff *= beta
        c,  # push accumulator
        PtxInstruction.add_ftz_f32,  # acc += diff
    ],
    epilogue_lambda=lambda e, reg: [
        e,  # push accumulator
        reg["gamma"],  # push gamma (2.0 at run time, read from the hyper["gamma"] tensor)
        PtxInstruction.mul_ftz_f32,  # scale output by gamma
    ],
)

hyper = {
    "beta": torch.tensor(0.5, device=device),
    "gamma": torch.tensor(2.0, device=device),
}
kernel(a=x, b=z, hyper_dict=hyper, out=out)
```

### Custom kernel explanation
This example defines a semiring where the "multiply" is `mul(a, b) = beta * (b - a)` and the "add" is standard addition. The epilogue then scales the accumulated sum by `gamma`.

With `beta=0.5` and `gamma=2.0`, the kernel computes:
```
out[m, n] = gamma * sum_k beta * (z[n, k] - x[m, k])
```
Because `gamma * beta = 1`, this simplifies to:
```
out[m, n] = sum_k (z[n, k] - x[m, k])
```
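
To make the stack semantics of the two lambdas concrete, here is a toy pure-Python evaluator that runs the same instruction sequences on plain floats. It is a model for intuition only, not how Stack PTX is implemented: `run_stack` and the string opcodes are hypothetical, and the operand order of `sub` (push `a`, push `b`, subtract gives `b - a`) is an assumption taken from the comments in the custom kernel sketch.

```python
def run_stack(program):
    """Evaluate a list of float literals and binary op names on a stack."""
    ops = {
        "sub": lambda first, second: second - first,  # push a, push b, sub -> b - a
        "mul": lambda first, second: first * second,
        "add": lambda first, second: first + second,
    }
    stack = []
    for item in program:
        if isinstance(item, str):
            second = stack.pop()  # most recently pushed operand
            first = stack.pop()   # operand pushed before it
            stack.append(ops[item](first, second))
        else:
            stack.append(item)
    return stack.pop()


def mma(a, b, c, reg):
    # Same shape as mma_lambda: acc = c + beta * (b - a)
    return [a, b, "sub", reg["beta"], "mul", c, "add"]


def epilogue(e, reg):
    # Same shape as epilogue_lambda: out = gamma * e
    return [e, reg["gamma"], "mul"]


reg = {"beta": 0.5, "gamma": 2.0}
x_row, z_row = [1.0, 2.0, 3.0], [4.0, 6.0, 8.0]
acc = 0.0
for a, b in zip(x_row, z_row):
    acc = run_stack(mma(a, b, acc, reg))  # one multiply-accumulate step per k
out = run_stack(epilogue(acc, reg))
print(out)  # 12.0, i.e. (4 - 1) + (6 - 2) + (8 - 3)
```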

### HyperSemiringGradientKernel
Gradient kernels split the work into three stages:
- `multiply_lambda` computes a per-element contribution from `a`, `b`, and `d` (often a derivative-like term).
- `accumulate_lambda` combines that contribution with `c` and accumulates into `e` across `k`.
- `epilogue_lambda` applies any final transform to `e`.

In the common pattern used by the zoo, the math looks like:
```
out[o, n, m] = epilogue( sum_k c[o, k] * multiply(d[n, m], b[n, k], a[k, m]) )
```

Example gradient kernel sketch:
```python
import torch
from mm_kermac import PtxInstruction
from mm_kermac.hyper_semiring_gradient import HyperSemiringGradientKernel

device = torch.device("cuda")
grad_kernel_matrix = torch.randn(256, 128, device=device)  # a: (K, M)
x = torch.randn(512, 256, device=device)                   # b: (N, K)
coefs = torch.randn(64, 256, device=device)                # c: (O, K)
z = torch.randn(512, 128, device=device)                   # d: (N, M)
out = torch.empty((64, 512, 128), device=device)           # e: (O, N, M)

kernel = HyperSemiringGradientKernel(
    multiply_lambda=lambda d, b, a, reg: [
        b,  # push b
        d,  # push d
        PtxInstruction.sub_ftz_f32,  # diff = d - b
        reg["alpha"],  # push alpha (0.25 from hyper["alpha"])
        PtxInstruction.mul_ftz_f32,  # diff *= alpha
        a,  # push a
        PtxInstruction.mul_ftz_f32,  # diff *= a
    ],
    accumulate_lambda=lambda c, diff, e, reg: [
        c,  # push c
        diff,  # push diff
        PtxInstruction.mul_ftz_f32,  # c * diff
        e,  # push accumulator
        PtxInstruction.add_ftz_f32,  # acc += c * diff
    ],
    epilogue_lambda=lambda e, reg: [
        e,  # push accumulator
        reg["scale"],  # push scale (2.0 from hyper["scale"])
        PtxInstruction.mul_ftz_f32,  # scale output
    ],
)

hyper = {
    "alpha": torch.tensor(0.25, device=device),
    "scale": torch.tensor(2.0, device=device),
}
kernel(a=grad_kernel_matrix, b=x, c=coefs, d=z, hyper_dict=hyper, out=out)
```

With `alpha=0.25` and `scale=2.0`, where `a` is `grad_kernel_matrix` and `c` is `coefs`, this computes:
```
out[o, n, m] = scale * sum_k c[o, k] * (alpha * (z[n, m] - x[n, k]) * a[k, m])
```
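
The formula above can be checked against a nested-loop reference on tiny inputs. The sketch below is a hypothetical pure-Python helper (`gradient_reference` is not part of `mm-kermac`) that follows the three stages literally, using nested lists with the shapes from the example: `a` is `(K, M)`, `b` is `(N, K)`, `c` is `(O, K)`, and `d` is `(N, M)`.

```python
def gradient_reference(a, b, c, d, alpha, scale):
    """Nested-loop reference returning out with shape (O, N, M)."""
    K, M = len(a), len(a[0])
    N, O = len(b), len(c)
    out = [[[0.0] * M for _ in range(N)] for _ in range(O)]
    for o in range(O):
        for n in range(N):
            for m in range(M):
                acc = 0.0
                for k in range(K):
                    # multiply stage: alpha * (d - b) * a
                    diff = alpha * (d[n][m] - b[n][k]) * a[k][m]
                    # accumulate stage: acc += c * diff
                    acc += c[o][k] * diff
                # epilogue stage: scale the reduced value
                out[o][n][m] = scale * acc
    return out


# K=2, M=1, N=1, O=1
result = gradient_reference(
    a=[[1.0], [2.0]],
    b=[[1.0, 3.0]],
    c=[[1.0, 1.0]],
    d=[[5.0]],
    alpha=0.25,
    scale=2.0,
)
print(result)  # [[[4.0]]]
```

In this tiny case the two `k` terms are `0.25 * (5 - 1) * 1 = 1.0` and `0.25 * (5 - 3) * 2 = 1.0`, so the scaled sum is `2.0 * 2.0 = 4.0`.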

See `examples/hyper_semiring_gradient.py` for concrete kernel definitions and expected shapes.

## mm-ptx
This repo relies on [mm-ptx](https://github.com/MetaMachines/mm-ptx-py) for Stack PTX and PTX Inject. See the mm-ptx repo for details on how Stack PTX works and how to use it, along with simplified examples.

## Tests
Tests are GPU-backed and require CUDA with `sm_80` or greater. They are implemented as unittest copies of the examples.

Run all tests:
```bash
python -m unittest discover -s tests -p 'test_*.py' -v
```

If CUDA is unavailable or your GPU is below `sm_80`, the tests will be skipped.
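
A skip condition of this kind can be written with the standard `unittest.skipUnless` decorator. The sketch below is an illustration under assumptions, not the project's actual test code: `meets_sm80`, `cuda_sm80_available`, and `ExampleGpuTest` are hypothetical names, and the `torch` import is guarded only so the snippet loads on machines without PyTorch.

```python
import unittest

try:
    import torch  # optional here so the sketch also loads without PyTorch
except ImportError:
    torch = None


def meets_sm80(capability):
    """True when a (major, minor) compute capability is sm_80 or newer."""
    return tuple(capability) >= (8, 0)


def cuda_sm80_available():
    """True only when PyTorch sees a CUDA device of at least sm_80."""
    if torch is None or not torch.cuda.is_available():
        return False
    return meets_sm80(torch.cuda.get_device_capability())


@unittest.skipUnless(cuda_sm80_available(), "requires CUDA with sm_80 or greater")
class ExampleGpuTest(unittest.TestCase):
    def test_device_is_supported(self):
        # Only runs when the class-level skip condition passed.
        self.assertTrue(meets_sm80(torch.cuda.get_device_capability()))
```

Comparing `(major, minor)` tuples handles cases such as `(8, 6)` and `(12, 0)` correctly without parsing the `sm_XY` string.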

## License
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.

## Citation
If you use this software in your work, please cite it using the following BibTeX entry (generated from the [CITATION.cff](CITATION.cff) file):
```bibtex
@software{Durham_mm-kermac_2025,
  author       = {Durham, Charlie},
  title        = {mm-kermac: Dynamically compiled hyper semirings for Pytorch using PTX Inject and Stack PTX},
  version      = {1.0.0},
  date-released = {2025-10-19},
  url          = {https://github.com/MetaMachines/mm-kermac-py}
}
```
