Metadata-Version: 2.4
Name: zmlx
Version: 0.7.12
Summary: ZMLX: Metal-kernel toolkit and optimization lab for MLX on Apple Silicon. Fused MoE decode (+5-12% on LFM2-8B-A1B), custom GPU kernels in one line, 70+ kernel catalog.
Project-URL: Homepage, https://github.com/Hmbown/ZMLX
Project-URL: Repository, https://github.com/Hmbown/ZMLX
Project-URL: Documentation, https://github.com/Hmbown/ZMLX#readme
Project-URL: Issues, https://github.com/Hmbown/ZMLX/issues
Project-URL: Changelog, https://github.com/Hmbown/ZMLX/blob/main/CHANGELOG.md
Author: Hunter Bown
License: MIT
License-File: LICENSE
Keywords: apple-silicon,autograd,jit,kernels,metal,mlx,zmlx
Classifier: Development Status :: 3 - Alpha
Classifier: Intended Audience :: Developers
Classifier: License :: OSI Approved :: MIT License
Classifier: Operating System :: MacOS :: MacOS X
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3 :: Only
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Classifier: Topic :: Software Development :: Libraries
Requires-Python: >=3.10
Requires-Dist: mlx>=0.30.0
Provides-Extra: dev
Requires-Dist: mypy>=1.10.0; extra == 'dev'
Requires-Dist: numpy>=1.24.0; extra == 'dev'
Requires-Dist: pytest>=8.0.0; extra == 'dev'
Requires-Dist: pyyaml>=6.0; extra == 'dev'
Requires-Dist: ruff>=0.5.0; extra == 'dev'
Requires-Dist: types-pyyaml>=6.0; extra == 'dev'
Provides-Extra: train
Requires-Dist: huggingface-hub>=0.20.0; extra == 'train'
Requires-Dist: mlx-lm>=0.25.0; extra == 'train'
Requires-Dist: pyyaml>=6.0; extra == 'train'
Requires-Dist: transformers>=4.40.0; extra == 'train'
Description-Content-Type: text/markdown

# ZMLX — Faster MoE inference on Apple Silicon

[![PyPI](https://img.shields.io/pypi/v/zmlx.svg)](https://pypi.org/project/zmlx/)
[![Python 3.10+](https://img.shields.io/badge/python-3.10%2B-blue.svg)](https://www.python.org/downloads/)
[![License: MIT](https://img.shields.io/badge/license-MIT-green.svg)](LICENSE)
[![Platform: macOS Apple Silicon](https://img.shields.io/badge/platform-macOS%20Apple%20Silicon-lightgrey.svg)](https://github.com/ml-explore/mlx)

ZMLX patches [MLX](https://github.com/ml-explore/mlx) models with fused Metal kernels for faster Mixture-of-Experts decode on Apple Silicon. The supported models run on **stock MLX**: no model conversion and no config changes. Install the package and call `patch(model)`.

Quick start (model patching; requires `zmlx[train]` or `mlx-lm`):

```python
import mlx_lm
from zmlx.patch import patch

model, tokenizer = mlx_lm.load("mlx-community/LFM2-8B-A1B-4bit")
patch(model)

text = mlx_lm.generate(model, tokenizer,
    prompt="Explain mixture-of-experts in one paragraph.",
    max_tokens=200)
```

---

## Install (recommended)

```bash
pip install "zmlx[train]"
```

That includes `mlx-lm` and everything needed for model patching. For kernel authoring only:

```bash
pip install zmlx
```

## Supported mode

| Mode | What you need | What you get |
|:--|:--|:--|
| **Stable** (default) | Stock MLX + `zmlx` | Token-identical output, real decode speedups |

`patch()` auto-detects which patterns are safe for your model family.

---

## Benchmarks — Stable mode (stock MLX)

LFM2-8B-A1B: **+5-12% decode**, token-identical, measured on M1 Pro and M4 Max. Prefill is neutral by design: the fused kernels activate only when M <= 32.

<details>
<summary>LFM2-8B-A1B on M1 Pro 16 GB</summary>

> macOS 14.6.1 · MLX 0.30.4 · ZMLX 0.7.12 · Python 3.10.0 · commit `7de879e`
>
> **Repro capsule:** [`benchmarks/repro_capsules/lfm2_m1pro_20260131.json`](benchmarks/repro_capsules/lfm2_m1pro_20260131.json) · **Print report:** `python -m zmlx.bench.report <capsule.json>`

**4-bit** ([mlx-community/LFM2-8B-A1B-4bit](https://huggingface.co/mlx-community/LFM2-8B-A1B-4bit))

| Metric | Baseline | Patched | Change |
|:--|--:|--:|:--|
| Decode | 105.5 tok/s | 115.3 tok/s | **+9.3%** |
| Prefill | 225.4 tok/s | 227.1 tok/s | +0.8% (neutral) |
| Fidelity | — | 430/430 | token-identical |
| Peak memory | — | 5.3 GB | |

**8-bit** ([mlx-community/LFM2-8B-A1B-8bit-MLX](https://huggingface.co/mlx-community/LFM2-8B-A1B-8bit-MLX))

| Metric | Baseline | Patched | Change |
|:--|--:|--:|:--|
| Decode | 72.8 tok/s | 76.4 tok/s | **+5.0%** |
| Prefill | 180.5 tok/s | 182.8 tok/s | +1.3% (neutral) |
| Fidelity | — | 500/500 | token-identical |
| Peak memory | — | 9.5 GB | |

</details>

<details>
<summary>LFM2-8B-A1B on M4 Max 36 GB</summary>

> macOS 26.1 · MLX 0.30.1 · ZMLX 0.7.12 · Python 3.12 · commit `139993e`
>
> **Repro capsule:** [`benchmarks/repro_capsules/lfm2_m4max_20260131.json`](benchmarks/repro_capsules/lfm2_m4max_20260131.json) · **Print report:** `python -m zmlx.bench.report <capsule.json>`

**4-bit** ([mlx-community/LFM2-8B-A1B-4bit](https://huggingface.co/mlx-community/LFM2-8B-A1B-4bit))

| Metric | Baseline | Patched | Change |
|:--|--:|--:|:--|
| Decode | 223.7 tok/s | 250.3 tok/s | **+11.9%** |
| Prefill | 737.4 tok/s | 755.4 tok/s | +2.4% (neutral) |
| Fidelity | — | 430/430 | token-identical |
| Peak memory | — | 5.30 GB | |

**8-bit** ([mlx-community/LFM2-8B-A1B-8bit-MLX](https://huggingface.co/mlx-community/LFM2-8B-A1B-8bit-MLX))

| Metric | Baseline | Patched | Change |
|:--|--:|--:|:--|
| Decode | 152.5 tok/s | 164.3 tok/s | **+7.7%** |
| Prefill | 557.6 tok/s | 564.4 tok/s | +1.2% (neutral) |
| Fidelity | — | 500/500 | token-identical |
| Peak memory | — | 9.45 GB | |

</details>

Full methodology and raw data: [`docs/BENCHMARKS.md`](docs/BENCHMARKS.md).

---

## How It Works

### The problem: dispatch overhead in MoE decode

In Mixture-of-Experts models, each token is routed to a subset of expert networks. During decode (generating one token at a time), the computation per expert is small — a few matrix multiplies on a single row vector. But the standard inference path dispatches multiple Metal kernels per expert per layer:

1. **Gating:** `softmax(logits)` → `argpartition` → `gather` → `normalize` — 4 dispatches
2. **Expert execution:** gate projection, up projection, SwiGLU activation, down projection — per expert
3. **Combine:** element-wise multiply by gating weights → reduce-sum across experts — 2 dispatches

On Apple Silicon, each Metal kernel dispatch has fixed overhead (command buffer encoding, GPU scheduling). When the actual compute per dispatch is small — as it is for M=1 decode — this overhead dominates. The GPU spends more time waiting between kernels than doing math.
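
To make the argument concrete, here is a toy cost model of a decode step. The dispatch counts (4 for gating plus 2 for combine, versus 1 plus 1 when fused) come from the list above; the per-dispatch overhead and compute times are placeholder values chosen for illustration, not measurements.

```python
def overhead_fraction(num_dispatches: int, overhead_us: float, compute_us: float) -> float:
    """Fraction of total time spent on fixed per-dispatch overhead."""
    overhead = num_dispatches * overhead_us
    return overhead / (overhead + compute_us)


# Placeholder timings (microseconds), for illustration only.
OVERHEAD_US = 5.0
COMPUTE_US = 10.0

unfused = overhead_fraction(num_dispatches=6, overhead_us=OVERHEAD_US, compute_us=COMPUTE_US)
fused = overhead_fraction(num_dispatches=2, overhead_us=OVERHEAD_US, compute_us=COMPUTE_US)
print(f"unfused: {unfused:.0%} overhead, fused: {fused:.0%} overhead")
```

With the same amount of math, fewer dispatches means a smaller share of time lost to overhead. The real ratio depends on the hardware and the model.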

<details>
<summary>What ZMLX fuses</summary>

ZMLX replaces the multi-dispatch sequences with single Metal kernels that do the same math in one pass. All fused kernels are generated from Python via [`mx.fast.metal_kernel`](https://ml-explore.github.io/mlx/build/html/python/fast.html) — no changes to MLX core required.

**Fused top-k gating softmax** (`topk_gating_softmax`):

Replaces the 4-dispatch gating sequence with a single kernel. For small expert counts (D <= 32, common in MoE), the kernel uses SIMD group operations — each row is processed by one SIMD group (32 threads), with `simd_max` and `simd_sum` for the softmax reduction and a register-based insertion sort for top-k selection. For larger D, a threadgroup reduction with shared memory is used. The kernel computes softmax probabilities and selects the top-k experts with their normalized weights in one pass.
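
As a reading aid, the math that the gating sequence performs can be written as a plain-Python reference. This is a sketch of the softmax, top-k, renormalize steps only, not the Metal kernel; the function name, the tie-breaking rule, and the descending output order are assumptions made for this example.

```python
import math


def topk_gating_reference(logits, k):
    """Softmax over expert logits, keep the k largest, renormalize to sum to 1."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]  # subtract max for stability
    total = sum(exps)
    probs = [e / total for e in exps]
    # Assumed tie-break: lower expert index wins.
    top = sorted(range(len(probs)), key=lambda i: (-probs[i], i))[:k]
    norm = sum(probs[i] for i in top)
    return top, [probs[i] / norm for i in top]


indices, weights = topk_gating_reference([0.1, 2.0, -1.0, 1.5], k=2)
print(indices, weights)
```

Experts 1 and 3 have the largest logits, so they are selected and their two weights sum to 1.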

**Fused expert combine** (`moe_combine`):

Replaces the separate element-wise multiply and reduce-sum with a single kernel that reads each expert output once, multiplies by its gating weight, and accumulates the weighted sum in float32. Output shape goes directly from `(B, K, D)` to `(B, D)` without materializing the intermediate `weights * expert_outputs` tensor. Here `K` is the number of selected experts and `D` is the output feature dimension, not the expert count that `D` denotes in the gating kernel description.
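
The combine step can be sketched the same way. This reference uses nested Python lists and Python floats (float64, not float32), and the function name is hypothetical; it only illustrates that the weighted sum can be accumulated directly, without storing a `(B, K, D)` product.

```python
def moe_combine_reference(expert_outputs, weights):
    """expert_outputs: B x K x D nested lists; weights: B x K. Returns B x D."""
    out = []
    for rows, w in zip(expert_outputs, weights):
        d = len(rows[0])
        acc = [0.0] * d  # running weighted sum for this row
        for row, wk in zip(rows, w):
            for j in range(d):
                acc[j] += wk * row[j]
        out.append(acc)
    return out


y = moe_combine_reference(
    [[[1.0, 2.0], [3.0, 4.0]]],  # B=1, K=2, D=2
    [[0.25, 0.75]],              # gating weights per selected expert
)
print(y)
```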

### Why prefill is unaffected

All fused kernels are guarded with a sequence length check (`M <= 32`). During prefill, M equals the prompt length (typically hundreds or thousands of tokens). At this scale, the compute-to-dispatch ratio is high and the standard MLX path is already efficient. The guards ensure ZMLX never regresses prefill performance.
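
A guard of this kind can be sketched as a small wrapper. The names `guarded`, `fused_fn`, and `fallback_fn` are hypothetical and this is not ZMLX's actual code; only the `M <= 32` threshold is taken from the text.

```python
FUSED_MAX_M = 32  # threshold quoted in the text


def guarded(fused_fn, fallback_fn, max_m=FUSED_MAX_M):
    """Return a callable that uses fused_fn only when the row count M is small."""
    def run(rows):
        m = len(rows)
        return fused_fn(rows) if m <= max_m else fallback_fn(rows)
    return run


# Stand-in callables that just report which path was taken.
layer = guarded(lambda rows: "fused", lambda rows: "fallback")
print(layer([0]), layer([0] * 33))
```

A single decoded token takes the fused path, while a 33-row prefill chunk falls through to the standard path.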

### Correctness guarantee

Token fidelity is a first-class requirement. `patch()` auto-detects the model family and excludes patterns with known fidelity issues. The fused gating kernel reproduces the exact same top-k selection and softmax normalization as the reference MLX ops. The combine kernel accumulates in float32 (or dtype-matched for Qwen3's `moe_combine_exact`). `python -m zmlx.validate` compares every generated token ID between patched and unpatched models under greedy decoding.
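
The kind of check described here, comparing token IDs one by one, can be sketched as follows. This is an illustration of the idea, not the implementation of `zmlx.validate`; the function name is hypothetical.

```python
def first_divergence(baseline_ids, patched_ids):
    """Return the index of the first differing token, or None if identical."""
    for i, (a, b) in enumerate(zip(baseline_ids, patched_ids)):
        if a != b:
            return i
    if len(baseline_ids) != len(patched_ids):
        # One sequence is a strict prefix of the other.
        return min(len(baseline_ids), len(patched_ids))
    return None


print(first_divergence([1, 2, 3], [1, 2, 3]))  # identical sequences
print(first_divergence([1, 2, 3], [1, 9, 3]))  # differs at index 1
```

Under greedy decoding both runs are deterministic, so a result of `None` over the whole generation corresponds to what the tables call "token-identical".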

</details>

### Patching options

```python
import mlx.core as mx
from zmlx.patch import patch, smart_patch

# `model` and `tokenizer` come from mlx_lm.load(...), as in the quick start.
# The patch() calls below are alternatives; pick one.
patch(model)                       # auto-detect, apply safe defaults
patch(model, patterns=["moe_mlp"]) # force specific pattern (overrides safety)
patch(model, mode="training")      # add norm fusions for backward pass

# Auto-benchmark: apply only patterns that actually help
sample = mx.array([tokenizer.encode("Hello")])
model = smart_patch(model, sample)
```

---

## Model Support

### Stable

Token-identical output, measurable decode improvement. Safe to use without further validation.

| Model | Decode speedup | Fidelity | Patterns |
|:--|:--|:--|:--|
| **LFM2-8B-A1B-4bit** | **+9-12%** | token-identical | `moe_mlp` + `swiglu_mlp` |
| **LFM2-8B-A1B-8bit** | **+5-8%** | token-identical | `moe_mlp` + `swiglu_mlp` |
| **Qwen3-30B-A3B-4bit** | **+7%** | token-identical | `moe_mlp` |
| **GPT-OSS-20B-MXFP4-Q4** | **+1%** | token-identical | `moe_mlp` |

### Tested (no gain)

| Model | Status | Notes |
|:--|:--|:--|
| Nemotron-3-Nano-30B-A3B-NVFP4 | 0.999x, PASS | Hybrid Mamba-MoE, bandwidth-limited at 19.4 GB |
| LFM2.5-1.2B-Thinking-MLX-8bit | 0.997x, PASS | Dense model, no matched MoE patterns |
| Qwen3-4B-4bit (dense) | diverges at token 18 | Dense model, patches not expected to help |
| Llama-3.2-1B-4bit | 0.98x, PASS | Dense model, bandwidth-bound |

For unlisted models: `python -m zmlx.validate <model>`.

---

## Toolkit

ZMLX is also a Metal kernel authoring toolkit for MLX:

- **70+ kernel catalog** — SwiGLU, GeGLU, fused dropout, MoE gating, RMSNorm, RoPE, quantization
- **One-line kernel authoring** — `elementwise("x * tanh(log(1 + exp(x)))")` compiles to Metal
- **Automatic gradients** — custom VJP backward passes as Metal kernels via `mx.custom_function`
- **Benchmarking** — `zmlx.bench.compare()` for side-by-side timing, `zmlx.bench.report` for repro capsules

```python
from zmlx.api import elementwise
import mlx.core as mx

mish = elementwise("x * tanh(log(1 + exp(x)))", name="mish")
y = mish(mx.random.normal((1024,)))
```

<details>
<summary>Op-level microbenchmarks</summary>

B=16, S=1024, D=1024, float16, M4 Max. `python benchmarks/microbench.py`:

| Operation | MLX | ZMLX | Speedup |
|:--|--:|--:|:--|
| **SwiGLU** | 0.87 ms | **0.43 ms** | **2.0x** |
| **Dropout** | 3.08 ms | **0.41 ms** | **7.5x** |
| **Top-K** | 1.81 ms | **0.49 ms** | **3.7x** |
| **Gather-Add** | 0.55 ms | **0.42 ms** | **1.3x** |
| Softmax | 0.45 ms | 0.44 ms | ~1.0x |
| RMSNorm | 0.51 ms | 0.54 ms | 0.95x |

ZMLX helps most for **fused operations** that MLX doesn't provide as single ops. MLX built-ins (`mx.fast.rms_norm`, `mx.softmax`) are already highly optimized.

</details>

### Kernel catalog

70+ kernels organized by domain. Full reference: [`docs/KERNELS.md`](docs/KERNELS.md).

| Module | Highlights |
|:---|:---|
| `moe` | `topk_gating_softmax`, `moe_dispatch`, `moe_combine` — fused expert routing |
| `transformer` | `swiglu`, `geglu`, `rmsnorm_residual`, `dropout` — genuine fusions |
| `loss` | `softmax_cross_entropy` — memory-efficient fused loss |
| `bits` | `pack_bits`, `unpack_bits` — no MLX equivalent |
| `quant` | FP8, NF4, int8, int4 dequantization |
| `norms` | `rmsnorm`, `layernorm` — float32 internal compute |
| `rope` | `apply_rope`, `apply_rope_interleaved`, `apply_gqa_rope` |
| `optimizers` | `adamw_step` — fused parameter update |

---

## Install (full)

**Requirements**: macOS (Apple Silicon), Python >= 3.10, mlx >= 0.30.0

```bash
pip install "zmlx[train]"
```

Kernel authoring only:

```bash
pip install zmlx
```

From source:

```bash
git clone https://github.com/Hmbown/ZMLX.git
cd ZMLX
pip install -e ".[dev]"
```

---

## Quick Start

```python
import mlx.core as mx
from zmlx import msl
from zmlx.api import elementwise

# Non-differentiable
fast_exp = elementwise("metal::exp(x)", name="fast_exp")
y = fast_exp(mx.random.normal((1024,)))

# Differentiable with custom VJP
silu = elementwise(
    "kk_silu(x)",
    name="my_silu",
    grad_expr="g * (s + x * s * ((T)1 - s))",
    grad_prelude="T s = kk_sigmoid(x);",
    use_output=False,
    header=msl.DEFAULT_HEADER,
)
gx = mx.grad(lambda z: silu(z).sum())(mx.random.normal((1024,)))
```
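
The `grad_expr` above is the derivative of SiLU written in terms of the sigmoid `s`. The plain-Python check below compares that formula against a central finite difference. It assumes that `kk_silu(x)` computes `x * sigmoid(x)`, that `kk_sigmoid` is the logistic function, and that `g` is the upstream gradient.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def silu(x):
    return x * sigmoid(x)


def silu_grad(x, g=1.0):
    s = sigmoid(x)                      # mirrors grad_prelude
    return g * (s + x * s * (1.0 - s))  # mirrors grad_expr


def finite_difference(f, x, h=1e-5):
    return (f(x + h) - f(x - h)) / (2.0 * h)


for x in (-2.0, 0.0, 0.5, 3.0):
    assert abs(silu_grad(x) - finite_difference(silu, x)) < 1e-8
print("grad_expr matches the numerical derivative")
```

The formula follows from the product rule: d/dx [x * s] = s + x * s', with s' = s * (1 - s).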

<details>
<summary>More examples: reductions, softmax, testing</summary>

### Custom reduction

```python
from zmlx.api import reduce
import mlx.core as mx

my_sum = reduce(init="0.0f", update="acc + v", name="row_sum")
y = my_sum(mx.random.normal((8, 1024)))  # shape (8,)
```

### Two-pass map-reduce (softmax pattern)

```python
from zmlx.api import map_reduce
import mlx.core as mx

my_softmax = map_reduce(
    pass1={"init": "-INFINITY", "update": "max(acc1, x)", "reduce": "max(a, b)"},
    pass2={"init": "0.0f", "update": "acc2 + exp(x - s1)", "reduce": "a + b"},
    write="exp(x - s1) / s2",
    name="my_softmax",
)
y = my_softmax(mx.random.normal((8, 1024)))
```
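
For orientation, the two passes can be read as the following plain-Python loop over one row. Interpreting `s1` and `s2` as the results of pass 1 and pass 2 is inferred from the expressions above rather than from documentation, so treat this as a sketch of the pattern and not of the generated kernel.

```python
import math


def two_pass_softmax(row):
    s1 = -math.inf
    for x in row:  # pass 1: running maximum
        s1 = max(s1, x)
    s2 = 0.0
    for x in row:  # pass 2: sum of shifted exponentials
        s2 = s2 + math.exp(x - s1)
    return [math.exp(x - s1) / s2 for x in row]  # write step


probs = two_pass_softmax([1.0, 2.0, 3.0])
print(probs)
```

Subtracting the row maximum before exponentiating keeps `exp` from overflowing on large inputs.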

### Test and benchmark

```python
import mlx.core as mx
import zmlx.bench
import zmlx.testing

# `my_softmax` is the kernel defined in the map-reduce example above.
zmlx.testing.assert_matches(
    my_softmax, lambda x: mx.softmax(x, axis=-1),
    shapes=[(8, 1024), (32, 4096)],
)

zmlx.bench.compare(
    {"ZMLX": my_softmax, "MLX": lambda x: mx.softmax(x, axis=-1)},
    shapes=[(1024, 4096), (4096, 4096)],
)
```

</details>

---

## Precision

All Python-level Metal kernels compute internally in **float32** regardless of input dtype. When exact dtype behavior matters (e.g., bfloat16 accumulation order), ZMLX provides specialized kernels to match MLX’s semantics.
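
The effect of accumulator width is easy to see in a small standard-library experiment. The sketch below simulates a narrow accumulator by rounding after every addition with `struct`; it uses IEEE float16 as a stand-in (bfloat16 is not available in `struct`) and says nothing about how ZMLX kernels are written.

```python
import struct


def round_to(fmt, value):
    """Round a Python float to IEEE half ('<e') or single ('<f') precision."""
    return struct.unpack(fmt, struct.pack(fmt, value))[0]


def accumulate(values, fmt):
    acc = 0.0
    for v in values:
        acc = round_to(fmt, acc + v)  # round after every add
    return acc


ones = [1.0] * 4096
half_sum = accumulate(ones, "<e")    # float16 accumulator
single_sum = accumulate(ones, "<f")  # float32 accumulator
print(half_sum, single_sum)          # 2048.0 4096.0
```

The float16 accumulator stalls at 2048 because the next representable value is 2050, so adding 1.0 rounds back down. The float32 accumulator returns the exact sum.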

---

## Documentation

- [`docs/TOUR.md`](docs/TOUR.md) — Quick walkthrough and orientation
- [`docs/QUICKSTART.md`](docs/QUICKSTART.md) — 5-minute tutorial
- [`docs/COOKBOOK.md`](docs/COOKBOOK.md) — Recipes for common patterns
- [`docs/KERNELS.md`](docs/KERNELS.md) — Complete kernel catalog
- [`docs/BENCHMARKS.md`](docs/BENCHMARKS.md) — Detailed benchmark methodology
- [`docs/EXPERIMENTAL_MLX.md`](docs/EXPERIMENTAL_MLX.md) — Optional custom-MLX experiments
- [`docs/ARCHITECTURE.md`](docs/ARCHITECTURE.md) — Design philosophy
- [`UPSTREAM_PLAN.md`](UPSTREAM_PLAN.md) — What belongs upstream in MLX

---

## Acknowledgments

ZMLX is built on [MLX](https://github.com/ml-explore/mlx) by Apple machine learning research. If you use ZMLX in your work, please also cite MLX:

```bibtex
@software{mlx2023,
  author = {Awni Hannun and Jagrit Digani and Angelos Katharopoulos and Ronan Collobert},
  title = {{MLX}: Efficient and flexible machine learning on Apple silicon},
  url = {https://github.com/ml-explore},
  version = {0.0},
  year = {2023},
}
```

---

## Contributing

See [`CONTRIBUTING.md`](CONTRIBUTING.md) for setup, testing, and conventions.

---

## License

MIT. See [`LICENSE`](LICENSE).
