Metadata-Version: 2.3
Name: jax2onnx
Version: 0.2.0
Summary: export JAX to ONNX - focus on flax nnx
Author: enpasos
Author-email: matthias.unverzagt@enpasos.ai
Requires-Python: >=3.12,<4
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.12
Classifier: Programming Language :: Python :: 3.13
Requires-Dist: flax (>=0.10.4)
Requires-Dist: jax (>=0.5.2)
Requires-Dist: ml_dtypes (==0.5.1)
Requires-Dist: netron (>=8.1.9)
Requires-Dist: onnx (>=1.17.0)
Requires-Dist: onnxruntime (>=1.21.0)
Requires-Dist: optax (==0.2.4)
Requires-Dist: orbax-checkpoint (==0.11.6)
Requires-Dist: orbax-export (==0.0.6)
Description-Content-Type: text/markdown

# jax2onnx 🌟

`jax2onnx` converts your JAX/Flax functions directly into the ONNX format.

![img.png](https://enpasos.github.io/jax2onnx/images/jax2onnx.png)

**Key features:**

- **JAXPR-Based Conversion:** Uses JAX's built-in `jaxpr` as the foundation for conversion.
- **Flexible Plugin System:** Easy-to-write Python plugins handle specific JAX primitives or examples. Plugin registration and testing are automated.
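
To illustrate the idea behind both features, the sketch below walks a hand-written list of equations (a stand-in for a real `jaxpr`) and dispatches each primitive to a small handler that emits ONNX operator names. The `Equation` type, the `PLUGINS` registry, and the `register` decorator are hypothetical names chosen for illustration and are not the `jax2onnx` plugin API; the primitive-to-operator pairs are taken from the table of supported components in this README.

```python
from dataclasses import dataclass
from typing import Callable

# Hypothetical stand-in for one jaxpr equation: a primitive name plus variable names.
@dataclass(frozen=True)
class Equation:
    primitive: str
    inputs: tuple[str, ...]
    output: str

# Hypothetical registry: primitive name -> handler returning ONNX op types.
PLUGINS: dict[str, Callable[[Equation], list[str]]] = {}

def register(primitive: str):
    """Decorator that files a handler under the primitive it converts."""
    def wrap(handler):
        PLUGINS[primitive] = handler
        return handler
    return wrap

@register("tanh")
def _tanh(eqn):
    return ["Tanh"]

@register("add")
def _add(eqn):
    return ["Add"]

@register("ne")
def _ne(eqn):
    # Per the components table, lax.ne is expressed as Equal followed by Not.
    return ["Equal", "Not"]

def convert(equations):
    """Return the ONNX op types emitted for each equation, in order."""
    ops = []
    for eqn in equations:
        handler = PLUGINS.get(eqn.primitive)
        if handler is None:
            raise NotImplementedError(f"no plugin for primitive '{eqn.primitive}'")
        ops.extend(handler(eqn))
    return ops

program = [
    Equation("tanh", ("x",), "a"),
    Equation("add", ("a", "y"), "b"),
    Equation("ne", ("b", "y"), "c"),
]
print(convert(program))  # ['Tanh', 'Add', 'Equal', 'Not']
```

In this toy model, a primitive without a registered handler raises `NotImplementedError`, which mirrors the "unsupported primitives" case described under Troubleshooting.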



---

## 🚀 Quickstart

Here's how simple it is to convert your JAX callable to ONNX:

```python
from jax2onnx import save_onnx
from flax import nnx 

# Example: A minimal MLP (from Flax documentation)
class MLP(nnx.Module):
  def __init__(self, din, dmid, dout, *, rngs):
    self.linear1 = nnx.Linear(din, dmid, rngs=rngs)
    self.dropout = nnx.Dropout(rate=0.1, rngs=rngs)
    self.bn = nnx.BatchNorm(dmid, rngs=rngs)
    self.linear2 = nnx.Linear(dmid, dout, rngs=rngs)
  def __call__(self, x):
    x = nnx.gelu(self.dropout(self.bn(self.linear1(x))))
    return self.linear2(x)

# Instantiate model
my_callable = MLP(din=30, dmid=20, dout=10, rngs=nnx.Rngs(0))

# Convert and save to ONNX
save_onnx(
    my_callable,
    [('B', 30)],          # Input shapes, batch size 'B' is symbolic
    "my_callable.onnx"    # Output path
)
```
The result: [`my_callable.onnx`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/my_callable.onnx)
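
The `'B'` entry in the input shape is a symbolic dimension, so the exported graph is not tied to one batch size. As a rough illustration of what that means when you later prepare inputs for the model, the helper below binds symbolic names to concrete sizes. `resolve_shape` is a hypothetical helper written for this README and is not part of `jax2onnx`.

```python
def resolve_shape(spec, bindings):
    """Replace symbolic dimension names (strings) with concrete sizes."""
    resolved = []
    for dim in spec:
        if isinstance(dim, str):
            if dim not in bindings:
                raise KeyError(f"no size given for symbolic dimension '{dim}'")
            resolved.append(bindings[dim])
        else:
            resolved.append(dim)
    return tuple(resolved)

input_spec = ('B', 30)  # same shape spec as in the Quickstart
print(resolve_shape(input_spec, {'B': 1}))   # (1, 30)
print(resolve_shape(input_spec, {'B': 64}))  # (64, 30)
```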


---

## 📅 Roadmap and Releases

### **Planned Version**
- **0.3.0** *(Upcoming)*: Simplifying the plugin mechanism.

### **Current Production Version**
- **0.2.0** *(Released on PyPI)*: Rebased the implementation on `jaxpr`, improving usability and adding low-level `lax` components.

### **Past Versions**
- **0.1.0** *(Initial Approach, Not Released to PyPI)*: Produced ONNX exports for some `nnx` components and `nnx`-based examples, including a VisionTransformer.

---

## ❓ Troubleshooting

If conversion doesn't work out of the box, it could be due to:

- **Non-dynamic function references:**  
  JAXPR-based conversion requires function references to be resolved dynamically at call-time.  
  **Solution:** Wrap your function call inside a lambda to enforce dynamic resolution:
  ```python
  my_dynamic_callable_function = lambda x: original_function(x)
  ```

- **Unsupported primitives:**  
  The callable may use a primitive not yet or not fully supported by `jax2onnx`.  
  **Solution:** Write a [plugin](#how-to-contribute) to handle the unsupported function (this is straightforward!).
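
The lambda workaround in the first item relies on ordinary Python name lookup: a stored reference keeps pointing at the original function object, while a lambda looks the name up again on every call. The sketch below shows that difference with a stand-in namespace. `lib`, `activation`, and the replacement function are hypothetical, and the scenario of a tool swapping functions out during tracing is an assumption made for illustration, not a description of `jax2onnx` internals.

```python
from types import SimpleNamespace

# Hypothetical library namespace with one function.
lib = SimpleNamespace(activation=lambda x: f"original({x})")

# Early binding: captures the function object that exists right now.
direct_ref = lib.activation

# Late binding: looks up lib.activation each time it is called.
dynamic_ref = lambda x: lib.activation(x)

# Assumed scenario: a tool temporarily replaces the function while tracing.
lib.activation = lambda x: f"traced({x})"

print(direct_ref("x"))   # original(x): the stored reference misses the replacement
print(dynamic_ref("x"))  # traced(x): the lambda picks up the replacement
```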

---

## 🧩 Supported JAX/ONNX Components


<!-- AUTOGENERATED TABLE START -->

| JAX Component | ONNX Components | Testcases | Since |
|:-------------|:---------------|:---------|:------|
| [jnp.add](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.add.html) | [Add](https://onnx.ai/onnx/operators/onnx__Add.html) | [`add`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/jnp/add.onnx) ✅ | v0.1.0 |
| [jnp.concat](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.concat.html) | [Concat](https://onnx.ai/onnx/operators/onnx__Concat.html) | [`concat`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/jnp/concat.onnx) ✅ | v0.1.0 |
| [jnp.einsum](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.einsum.html) | [Einsum](https://onnx.ai/onnx/operators/onnx__Einsum.html) | [`einsum_dynamic_concrete`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/jnp/einsum_dynamic_concrete.onnx) ✅<br>[`einsum_dynamic_dynamic`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/jnp/einsum_dynamic_dynamic.onnx) ✅<br>[`einsum_dynamic_matmul2_concrete`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/jnp/einsum_dynamic_matmul2_concrete.onnx) ✅<br>[`einsum_dynamic_matmul2_dynamic`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/jnp/einsum_dynamic_matmul2_dynamic.onnx) ✅<br>[`einsum_dynamic_matmul3_concrete`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/jnp/einsum_dynamic_matmul3_concrete.onnx) ✅<br>[`einsum_dynamic_matmul3_dynamic`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/jnp/einsum_dynamic_matmul3_dynamic.onnx) ✅<br>[`einsum_dynamic_matmul_concrete`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/jnp/einsum_dynamic_matmul_concrete.onnx) ✅<br>[`einsum_dynamic_matmul_dynamic`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/jnp/einsum_dynamic_matmul_dynamic.onnx) ✅<br>[`einsum_dynamic_transpose_concrete`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/jnp/einsum_dynamic_transpose_concrete.onnx) ✅<br>[`einsum_dynamic_transpose_dynamic`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/jnp/einsum_dynamic_transpose_dynamic.onnx) ✅<br>[`einsum_matmul`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/jnp/einsum_matmul.onnx) ✅<br>[`einsum_transpose`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/jnp/einsum_transpose.onnx) ✅<br>[`einsum`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/jnp/einsum.onnx) ✅ | v0.1.0 |
| [jnp.matmul](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.matmul.html) | [MatMul](https://onnx.ai/onnx/operators/onnx__MatMul.html) | [`matmul_1d_2d`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/jnp/matmul_1d_2d.onnx) ✅<br>[`matmul_1d`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/jnp/matmul_1d.onnx) ✅<br>[`matmul_2d_1d`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/jnp/matmul_2d_1d.onnx) ✅<br>[`matmul_2d`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/jnp/matmul_2d.onnx) ✅<br>[`matmul_3d`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/jnp/matmul_3d.onnx) ✅<br>[`matmul_dynamic_a_concrete`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/jnp/matmul_dynamic_a_concrete.onnx) ✅<br>[`matmul_dynamic_a_dynamic`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/jnp/matmul_dynamic_a_dynamic.onnx) ✅<br>[`matmul_dynamic_b_concrete`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/jnp/matmul_dynamic_b_concrete.onnx) ✅<br>[`matmul_dynamic_b_dynamic`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/jnp/matmul_dynamic_b_dynamic.onnx) ✅<br>[`matmul_dynamic_concrete`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/jnp/matmul_dynamic_concrete.onnx) ✅<br>[`matmul_dynamic_dynamic`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/jnp/matmul_dynamic_dynamic.onnx) ✅ | v0.1.0 |
| [jnp.reshape](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.reshape.html) | [Reshape](https://onnx.ai/onnx/operators/onnx__Reshape.html) | [`reshape_1`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/jnp/reshape_1.onnx) ✅<br>[`reshape_2`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/jnp/reshape_2.onnx) ✅<br>[`reshape_3`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/jnp/reshape_3.onnx) ✅<br>[`reshape_4_concrete`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/jnp/reshape_4_concrete.onnx) ✅<br>[`reshape_4_dynamic`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/jnp/reshape_4_dynamic.onnx) ✅<br>[`reshape_from_scalar`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/jnp/reshape_from_scalar.onnx) ✅<br>[`reshape_to_scalar`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/jnp/reshape_to_scalar.onnx) ✅ | v0.1.0 |
| [jnp.squeeze](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.squeeze.html) | [Squeeze](https://onnx.ai/onnx/operators/onnx__Squeeze.html) | [`squeeze_all_dims`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/jnp/squeeze_all_dims.onnx) ✅<br>[`squeeze_dynamic_and_negative_axis_concrete`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/jnp/squeeze_dynamic_and_negative_axis_concrete.onnx) ✅<br>[`squeeze_dynamic_and_negative_axis_dynamic`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/jnp/squeeze_dynamic_and_negative_axis_dynamic.onnx) ✅<br>[`squeeze_dynamic_batch_concrete`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/jnp/squeeze_dynamic_batch_concrete.onnx) ✅<br>[`squeeze_dynamic_batch_dynamic`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/jnp/squeeze_dynamic_batch_dynamic.onnx) ✅<br>[`squeeze_multiple_dims`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/jnp/squeeze_multiple_dims.onnx) ✅<br>[`squeeze_negative_axis_tuple`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/jnp/squeeze_negative_axis_tuple.onnx) ✅<br>[`squeeze_negative_axis`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/jnp/squeeze_negative_axis.onnx) ✅<br>[`squeeze_single_dim`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/jnp/squeeze_single_dim.onnx) ✅<br>[`squeeze_vit_output`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/jnp/squeeze_vit_output.onnx) ✅ | v0.1.0 |
| [jnp.tile](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.tile.html) | [Tile](https://onnx.ai/onnx/operators/onnx__Tile.html) | [`tile_a`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/jnp/tile_a.onnx) ✅<br>[`tile_b`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/jnp/tile_b.onnx) ✅<br>[`tile_c`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/jnp/tile_c.onnx) ✅<br>[`tile_d`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/jnp/tile_d.onnx) ✅<br>[`tile_dynamic_concrete`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/jnp/tile_dynamic_concrete.onnx) ✅<br>[`tile_dynamic_dynamic`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/jnp/tile_dynamic_dynamic.onnx) ✅<br>[`tile_pad`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/jnp/tile_pad.onnx) ✅ | v0.1.0 |
| [jnp.transpose](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.transpose.html) | [Transpose](https://onnx.ai/onnx/operators/onnx__Transpose.html) | [`transpose_4d`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/jnp/transpose_4d.onnx) ✅<br>[`transpose_basic`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/jnp/transpose_basic.onnx) ✅<br>[`transpose_dynamic_concrete`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/jnp/transpose_dynamic_concrete.onnx) ✅<br>[`transpose_dynamic_dynamic`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/jnp/transpose_dynamic_dynamic.onnx) ✅<br>[`transpose_high_dim`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/jnp/transpose_high_dim.onnx) ✅<br>[`transpose_no_axes`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/jnp/transpose_no_axes.onnx) ✅<br>[`transpose_reverse`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/jnp/transpose_reverse.onnx) ✅<br>[`transpose_square_matrix`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/jnp/transpose_square_matrix.onnx) ✅ | v0.1.0 |
| [lax.add](https://docs.jax.dev/en/latest/_autosummary/jax.lax.add.html) | [Add](https://onnx.ai/onnx/operators/onnx__Add.html) | [`add`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/lax/add.onnx) ✅ | v0.2.0 |
| [lax.argmax](https://docs.jax.dev/en/latest/_autosummary/jax.lax.argmax.html) | [ArgMax](https://onnx.ai/onnx/operators/onnx__ArgMax.html) | [`argmax_test1`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/lax/argmax_test1.onnx) ✅<br>[`argmax_test2`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/lax/argmax_test2.onnx) ✅ | v0.2.0 |
| [lax.argmin](https://docs.jax.dev/en/latest/_autosummary/jax.lax.argmin.html) | [ArgMin](https://onnx.ai/onnx/operators/onnx__ArgMin.html) | [`argmin_test1`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/lax/argmin_test1.onnx) ✅<br>[`argmin_test2`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/lax/argmin_test2.onnx) ✅ | v0.2.0 |
| [lax.broadcast_in_dim](https://docs.jax.dev/en/latest/_autosummary/jax.lax.broadcast_in_dim.html) | [Expand](https://onnx.ai/onnx/operators/onnx__Expand.html) | [`broadcast_in_dim`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/lax/broadcast_in_dim.onnx) ✅ | v0.2.0 |
| [lax.concatenate](https://docs.jax.dev/en/latest/_autosummary/jax.lax.concatenate.html) | [Concat](https://onnx.ai/onnx/operators/onnx__Concat.html) | [`concatenate`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/lax/concatenate.onnx) ✅ | v0.2.0 |
| [lax.conv](https://docs.jax.dev/en/latest/_autosummary/jax.lax.conv.html) | [Conv](https://onnx.ai/onnx/operators/onnx__Conv.html) | [`conv2`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/lax/conv2.onnx) ✅<br>[`conv`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/lax/conv.onnx) ✅ | v0.2.0 |
| [lax.convert_element_type](https://docs.jax.dev/en/latest/_autosummary/jax.lax.convert_element_type.html) | [Cast](https://onnx.ai/onnx/operators/onnx__Cast.html) | [`convert_element_type`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/lax/convert_element_type.onnx) ✅ | v0.2.0 |
| [lax.div](https://docs.jax.dev/en/latest/_autosummary/jax.lax.div.html) | [Div](https://onnx.ai/onnx/operators/onnx__Div.html) | [`div`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/lax/div.onnx) ✅ | v0.2.0 |
| [lax.dot_general](https://docs.jax.dev/en/latest/_autosummary/jax.lax.dot_general.html) | [MatMul](https://onnx.ai/onnx/operators/onnx__MatMul.html) | [`dot_general`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/lax/dot_general.onnx) ✅ | v0.2.0 |
| [lax.dynamic_slice](https://docs.jax.dev/en/latest/_autosummary/jax.lax.dynamic_slice.html) | [Slice](https://onnx.ai/onnx/operators/onnx__Slice.html) | [`dynamic_slice_test1`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/lax/dynamic_slice_test1.onnx) ✅ | v0.1.0 |
| [lax.eq](https://docs.jax.dev/en/latest/_autosummary/jax.lax.eq.html) | [Equal](https://onnx.ai/onnx/operators/onnx__Equal.html) | [`eq`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/lax/eq.onnx) ✅ | v0.2.0 |
| [lax.exp](https://docs.jax.dev/en/latest/_autosummary/jax.lax.exp.html) | [Exp](https://onnx.ai/onnx/operators/onnx__Exp.html) | [`exp`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/lax/exp.onnx) ✅ | v0.2.0 |
| [lax.gather](https://docs.jax.dev/en/latest/_autosummary/jax.lax.gather.html) | [Gather](https://onnx.ai/onnx/operators/onnx__Gather.html) | [`gather`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/lax/gather.onnx) ✅ | v0.2.0 |
| [lax.gt](https://docs.jax.dev/en/latest/_autosummary/jax.lax.gt.html) | [Greater](https://onnx.ai/onnx/operators/onnx__Greater.html) | [`gt`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/lax/gt.onnx) ✅ | v0.2.0 |
| [lax.integer_pow](https://docs.jax.dev/en/latest/_autosummary/jax.lax.integer_pow.html) | [Pow](https://onnx.ai/onnx/operators/onnx__Pow.html) | [`integer_pow`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/lax/integer_pow.onnx) ✅ | v0.2.0 |
| [lax.log](https://docs.jax.dev/en/latest/_autosummary/jax.lax.log.html) | [Log](https://onnx.ai/onnx/operators/onnx__Log.html) | [`log`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/lax/log.onnx) ✅ | v0.2.0 |
| [lax.lt](https://docs.jax.dev/en/latest/_autosummary/jax.lax.lt.html) | [Less](https://onnx.ai/onnx/operators/onnx__Less.html) | [`lt`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/lax/lt.onnx) ✅ | v0.2.0 |
| [lax.max](https://docs.jax.dev/en/latest/_autosummary/jax.lax.max.html) | [Max](https://onnx.ai/onnx/operators/onnx__Max.html) | [`max`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/lax/max.onnx) ✅ | v0.2.0 |
| [lax.min](https://docs.jax.dev/en/latest/_autosummary/jax.lax.min.html) | [Min](https://onnx.ai/onnx/operators/onnx__Min.html) | [`min_test1`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/lax/min_test1.onnx) ✅ | v0.1.0 |
| [lax.mul](https://docs.jax.dev/en/latest/_autosummary/jax.lax.mul.html) | [Mul](https://onnx.ai/onnx/operators/onnx__Mul.html) | [`mul_test1`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/lax/mul_test1.onnx) ✅<br>[`mul_test2`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/lax/mul_test2.onnx) ✅ | v0.1.0 |
| [lax.ne](https://docs.jax.dev/en/latest/_autosummary/jax.lax.ne.html) | [Equal](https://onnx.ai/onnx/operators/onnx__Equal.html)<br>[Not](https://onnx.ai/onnx/operators/onnx__Not.html) | [`ne`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/lax/ne.onnx) ✅ | v0.2.0 |
| [lax.neg](https://docs.jax.dev/en/latest/_autosummary/jax.lax.neg.html) | [Neg](https://onnx.ai/onnx/operators/onnx__Neg.html) | [`neg`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/lax/neg.onnx) ✅ | v0.2.0 |
| [lax.reduce_max](https://docs.jax.dev/en/latest/_autosummary/jax.lax.reduce_max.html) | [ReduceMax](https://onnx.ai/onnx/operators/onnx__ReduceMax.html) | [`reduce_max`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/lax/reduce_max.onnx) ✅ | v0.2.0 |
| [lax.reduce_min](https://docs.jax.dev/en/latest/_autosummary/jax.lax.reduce_min.html) | [ReduceMin](https://onnx.ai/onnx/operators/onnx__ReduceMin.html) | [`reduce_min`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/lax/reduce_min.onnx) ✅ | v0.2.0 |
| [lax.reduce_sum](https://docs.jax.dev/en/latest/_autosummary/jax.lax.reduce_sum.html) | [ReduceSum](https://onnx.ai/onnx/operators/onnx__ReduceSum.html) | [`reduce_sum`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/lax/reduce_sum.onnx) ✅ | v0.2.0 |
| [lax.reshape](https://docs.jax.dev/en/latest/_autosummary/jax.lax.reshape.html) | [Reshape](https://onnx.ai/onnx/operators/onnx__Reshape.html) | [`reshape`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/lax/reshape.onnx) ✅ | v0.2.0 |
| [lax.slice](https://docs.jax.dev/en/latest/_autosummary/jax.lax.slice.html) | [Slice](https://onnx.ai/onnx/operators/onnx__Slice.html) | [`slice_test1`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/lax/slice_test1.onnx) ✅ | v0.1.0 |
| [lax.sort](https://docs.jax.dev/en/latest/_autosummary/jax.lax.sort.html) | [TopK](https://onnx.ai/onnx/operators/onnx__TopK.html) | [`sort_1d_empty`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/lax/sort_1d_empty.onnx) ✅<br>[`sort_1d_larger`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/lax/sort_1d_larger.onnx) ✅<br>[`sort_1d_single`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/lax/sort_1d_single.onnx) ✅<br>[`sort_1d_specific_values`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/lax/sort_1d_specific_values.onnx) ✅<br>[`sort_1d`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/lax/sort_1d.onnx) ✅ | v0.2.0 |
| [lax.sqrt](https://docs.jax.dev/en/latest/_autosummary/jax.lax.sqrt.html) | [Sqrt](https://onnx.ai/onnx/operators/onnx__Sqrt.html) | [`sqrt`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/lax/sqrt.onnx) ✅ | v0.2.0 |
| [lax.square](https://docs.jax.dev/en/latest/_autosummary/jax.lax.square.html) | [Mul](https://onnx.ai/onnx/operators/onnx__Mul.html) | [`square`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/lax/square.onnx) ✅ | v0.2.0 |
| [lax.squeeze](https://docs.jax.dev/en/latest/_autosummary/jax.lax.squeeze.html) | [Squeeze](https://onnx.ai/onnx/operators/onnx__Squeeze.html) | [`squeeze`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/lax/squeeze.onnx) ✅ | v0.2.0 |
| [lax.stop_gradient](https://docs.jax.dev/en/latest/_autosummary/jax.lax.stop_gradient.html) | [Identity](https://onnx.ai/onnx/operators/onnx__Identity.html) | [`stop_gradient`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/lax/stop_gradient.onnx) ✅ | v0.2.0 |
| [lax.sub](https://docs.jax.dev/en/latest/_autosummary/jax.lax.sub.html) | [Sub](https://onnx.ai/onnx/operators/onnx__Sub.html) | [`sub_test1`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/lax/sub_test1.onnx) ✅<br>[`sub_test2`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/lax/sub_test2.onnx) ✅ | v0.1.0 |
| [lax.tanh](https://docs.jax.dev/en/latest/_autosummary/jax.lax.tanh.html) | [Tanh](https://onnx.ai/onnx/operators/onnx__Tanh.html) | [`tanh`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/lax/tanh.onnx) ✅ | v0.2.0 |
| [lax.transpose](https://docs.jax.dev/en/latest/_autosummary/jax.lax.transpose.html) | [Transpose](https://onnx.ai/onnx/operators/onnx__Transpose.html) | [`transpose_basic`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/lax/transpose_basic.onnx) ✅ | v0.2.0 |
| [nn.softmax](https://jax.readthedocs.io/en/latest/_autosummary/jax.nn.softmax.html) | [Softmax](https://onnx.ai/onnx/operators/onnx__Softmax.html) | [`softmax`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/nn/softmax.onnx) ✅ | v0.1.0 |
| [nnx.avg_pool](https://flax-linen.readthedocs.io/en/latest/api_reference/flax.linen/layers.html#flax.linen.avg_pool) | [AveragePool](https://onnx.ai/onnx/operators/onnx__AveragePool.html)<br>[Transpose](https://onnx.ai/onnx/operators/onnx__Transpose.html) | [`avg_pool_default_padding`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/nnx/avg_pool_default_padding.onnx) ✅<br>[`avg_pool_same_padding`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/nnx/avg_pool_same_padding.onnx) ✅<br>[`avg_pool`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/nnx/avg_pool.onnx) ✅ | v0.1.0 |
| [nnx.batch_norm](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/normalization.html#flax.nnx.BatchNorm) | [BatchNormalization](https://onnx.ai/onnx/operators/onnx__BatchNormalization.html) | [`batch_norm_2`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/nnx/batch_norm_2.onnx) ✅<br>[`batch_norm`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/nnx/batch_norm.onnx) ✅ | v0.1.0 |
| [nnx.conv](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv_general_dilated.html) | [Conv](https://onnx.ai/onnx/operators/onnx__Conv.html)<br>[Transpose](https://onnx.ai/onnx/operators/onnx__Transpose.html) | [`conv_2`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/nnx/conv_2.onnx) ✅<br>[`conv_3`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/nnx/conv_3.onnx) ✅<br>[`conv_4`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/nnx/conv_4.onnx) ✅<br>[`conv_concrete`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/nnx/conv_concrete.onnx) ✅<br>[`conv_dynamic`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/nnx/conv_dynamic.onnx) ✅ | v0.1.0 |
| [nnx.dot_product_attention](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/attention.html#flax.nnx.dot_product_attention) | [Constant](https://onnx.ai/onnx/operators/onnx__Constant.html)<br>[Einsum](https://onnx.ai/onnx/operators/onnx__Einsum.html)<br>[Mul](https://onnx.ai/onnx/operators/onnx__Mul.html)<br>[Softmax](https://onnx.ai/onnx/operators/onnx__Softmax.html) | [`dot_product_attention`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/nnx/dot_product_attention.onnx) ✅ | v0.1.0 |
| [nnx.dropout](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/stochastic.html#flax.nnx.Dropout) | [Dropout](https://onnx.ai/onnx/operators/onnx__Dropout.html) | [`dropout_inference`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/nnx/dropout_inference.onnx) ✅ | v0.1.0 |
| [nnx.elu](https://jax.readthedocs.io/en/latest/_autosummary/jax.nn.elu.html) | [Elu](https://onnx.ai/onnx/operators/onnx__Elu.html) | [`elu`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/nnx/elu.onnx) ✅ | v0.1.0 |
| [nnx.gelu](https://jax.readthedocs.io/en/latest/_autosummary/jax.nn.gelu.html) | [Gelu](https://onnx.ai/onnx/operators/onnx__Gelu.html) | [`gelu_1`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/nnx/gelu_1.onnx) ✅<br>[`gelu_2`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/nnx/gelu_2.onnx) ✅<br>[`gelu_3`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/nnx/gelu_3.onnx) ✅<br>[`gelu`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/nnx/gelu.onnx) ✅ | v0.1.0 |
| [nnx.layer_norm](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/normalization.html#flax.nnx.LayerNorm) | [LayerNormalization](https://onnx.ai/onnx/operators/onnx__LayerNormalization.html) | [`layer_norm_multiaxis`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/nnx/layer_norm_multiaxis.onnx) ✅<br>[`layer_norm`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/nnx/layer_norm.onnx) ✅ | v0.1.0 |
| [nnx.leaky_relu](https://jax.readthedocs.io/en/latest/_autosummary/jax.nn.leaky_relu.html) | [LeakyRelu](https://onnx.ai/onnx/operators/onnx__LeakyRelu.html) | [`leaky_relu`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/nnx/leaky_relu.onnx) ✅ | v0.1.0 |
| [nnx.linear](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/linear.html) | [Gemm](https://onnx.ai/onnx/operators/onnx__Gemm.html)<br>[Reshape](https://onnx.ai/onnx/operators/onnx__Reshape.html) | [`linear_2d`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/nnx/linear_2d.onnx) ✅<br>[`linear_concrete`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/nnx/linear_concrete.onnx) ✅<br>[`linear_dynamic`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/nnx/linear_dynamic.onnx) ✅ | v0.1.0 |
| [nnx.linear_general](https://docs.jax.dev/en/latest/_autosummary/jax.lax.dot_general.html) | [Gemm](https://onnx.ai/onnx/operators/onnx__Gemm.html)<br>[Reshape](https://onnx.ai/onnx/operators/onnx__Reshape.html) | [`linear_general_2`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/nnx/linear_general_2.onnx) ✅<br>[`linear_general_3`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/nnx/linear_general_3.onnx) ✅<br>[`linear_general_4`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/nnx/linear_general_4.onnx) ✅<br>[`linear_general_concrete`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/nnx/linear_general_concrete.onnx) ✅<br>[`linear_general_dynamic`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/nnx/linear_general_dynamic.onnx) ✅ | v0.1.0 |
| [nnx.log_softmax](https://jax.readthedocs.io/en/latest/_autosummary/jax.nn.log_softmax.html) | [LogSoftmax](https://onnx.ai/onnx/operators/onnx__LogSoftmax.html) | [`log_softmax`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/nnx/log_softmax.onnx) ✅ | v0.1.0 |
| [nnx.max_pool](https://flax-linen.readthedocs.io/en/latest/api_reference/flax.linen/layers.html#flax.linen.max_pool) | [MaxPool](https://onnx.ai/onnx/operators/onnx__MaxPool.html)<br>[Transpose](https://onnx.ai/onnx/operators/onnx__Transpose.html) | [`max_pool_same_padding`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/nnx/max_pool_same_padding.onnx) ✅<br>[`max_pool`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/nnx/max_pool.onnx) ✅ | v0.1.0 |
| [nnx.relu](https://jax.readthedocs.io/en/latest/_autosummary/jax.nn.relu.html) | [Relu](https://onnx.ai/onnx/operators/onnx__Relu.html) | [`relu_2`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/nnx/relu_2.onnx) ✅<br>[`relu`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/nnx/relu.onnx) ✅ | v0.1.0 |
| [nnx.sigmoid](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/activations.html#flax.nnx.sigmoid) | [Sigmoid](https://onnx.ai/onnx/operators/onnx__Sigmoid.html) | [`sigmoid`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/nnx/sigmoid.onnx) ✅ | v0.1.0 |
| [nnx.softmax](https://jax.readthedocs.io/en/latest/_autosummary/jax.nn.softmax.html) | [Softmax](https://onnx.ai/onnx/operators/onnx__Softmax.html) | [`softmax`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/nnx/softmax.onnx) ✅ | v0.1.0 |
| [nnx.softplus](https://jax.readthedocs.io/en/latest/_autosummary/jax.nn.softplus.html) | [Softplus](https://onnx.ai/onnx/operators/onnx__Softplus.html) | [`softplus`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/nnx/softplus.onnx) ✅ | v0.1.0 |
| [nnx.tanh](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/activations.html#flax.nnx.tanh) | [Tanh](https://onnx.ai/onnx/operators/onnx__Tanh.html) | [`tanh`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/plugins/nnx/tanh.onnx) ✅ | v0.1.0 |

<!-- AUTOGENERATED TABLE END -->

**Legend:**  
✅ = Passed  
❌ = Failed  
➖ = No testcase yet

---

## 🎯 Examples

<!-- AUTOGENERATED EXAMPLES TABLE START -->

| Component | Description | Children | Testcases | Since |
|:----------|:------------|:---------|:---------|:------|
| AutoEncoder | A simple autoencoder example. | Encoder<br>Decoder | [`autoencoder`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/examples/nnx/autoencoder.onnx) ✅ | v0.2.0 |
| CNN | A simple convolutional neural network (CNN). | nnx.Conv<br>nnx.Linear<br>nnx.avg_pool<br>nnx.relu<br>lax.reshape | [`cnn`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/examples/nnx/cnn.onnx) ✅ | v0.1.0 |
| ConvEmbedding | Convolutional Token Embedding for MNIST with hierarchical downsampling. | flax.nnx.Conv<br>flax.nnx.LayerNorm<br>jax.numpy.Reshape<br>jax.nn.relu | [`mnist_conv_embedding`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/examples/nnx/mnist_conv_embedding.onnx) ✅ | v0.1.0 |
| MLP | A simple Multi-Layer Perceptron (MLP) with BatchNorm, Dropout, and GELU activation. | nnx.Linear<br>nnx.Dropout<br>nnx.BatchNorm<br>nnx.gelu | [`mlp_concrete`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/examples/nnx/mlp_concrete.onnx) ✅<br>[`mlp_dynamic`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/examples/nnx/mlp_dynamic.onnx) ✅ | v0.1.0 |
| MLPBlock | MLP in Transformer | flax.nnx.Linear<br>flax.nnx.Dropout<br>flax.nnx.gelu | [`mlp_block`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/examples/nnx/mlp_block.onnx) ✅ | v0.1.0 |
| MultiHeadAttention | A multi-head attention module implemented by Flax/nnx that has no ONNX counterpart at the same granularity. | nnx.LinearGeneral<br>nnx.dot_product_attention | [`multihead_attention`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/examples/nnx/multihead_attention.onnx) ✅ | v0.2.0 |
| PatchEmbedding | Cutting the image into patches and linearly embedding them. | flax.nnx.Linear<br>jax.numpy.Transpose<br>jax.numpy.Reshape | [`patch_embedding`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/examples/nnx/patch_embedding.onnx) ✅ | v0.1.0 |
| TransformerBlock | Transformer from 'Attention Is All You Need.' | flax.nnx.MultiHeadAttention<br>flax.nnx.LayerNorm<br>MLPBlock<br>flax.nnx.Dropout | [`transformer_block`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/examples/nnx/transformer_block.onnx) ✅ | v0.1.0 |
| VisionTransformer | An MNIST Vision Transformer (ViT) model with configurable convolutional or patch embedding. | PatchEmbedding<br>ConvEmbedding<br>TransformerBlock<br>nnx.Linear<br>nnx.LayerNorm | [`mnist_vit_conv`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/examples/nnx/mnist_vit_conv.onnx) ✅<br>[`mnist_vit_patch`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/examples/nnx/mnist_vit_patch.onnx) ✅ | v0.1.0 |

<!-- AUTOGENERATED EXAMPLES TABLE END -->

---

## 📌 Dependencies

**Versions of Major Dependencies:**

| Library       | jax2onnx v0.2.0 | 
|:--------------|:----------------| 
| `JAX`         | v0.5.2          | 
| `Flax`        | v0.10.4         | 
| `onnx`        | v1.17.0         |  
| `onnxruntime` | v1.21.0         |  

*Note: For more details, check `pyproject.toml`.*
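
To see how a local environment compares with the table above, a small standard-library check can help. This is a minimal sketch: `installed_versions` and `MINIMUMS` are hypothetical names, the minimum versions are copied from the table, and the script only reports what it finds rather than enforcing anything.

```python
from importlib.metadata import PackageNotFoundError, version

# Minimum versions copied from the dependency table above.
MINIMUMS = {
    "jax": "0.5.2",
    "flax": "0.10.4",
    "onnx": "1.17.0",
    "onnxruntime": "1.21.0",
}

def installed_versions(packages):
    """Map each distribution name to its installed version, or None if absent."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found

for name, installed in installed_versions(MINIMUMS).items():
    print(f"{name}: installed={installed}, expected>={MINIMUMS[name]}")
```

A value of `None` means the distribution is not installed in the active environment.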

---

## ⚠️ Limitations

- Not all JAX/Flax components are supported yet (you can help expand this coverage by writing a plugin!).
- Function references need dynamic resolution at call-time.
- ONNX graph composition is done in-memory before saving to disk, potentially causing memory issues with very large models.

---

## 🤝 How to Contribute

We warmly welcome contributions!

**How you can help:**

- **Add a plugin:** Extend `jax2onnx` by writing a simple Python file in [`jax2onnx/converter/plugins`](./jax2onnx/converter/plugins).
- **Provide examples:** Add an illustrative example to the [`examples`](./jax2onnx/examples) folder.
- **Bug fixes & improvements:** PRs and issues are always welcome.

---

## 💾 Installation

Install from PyPI:

```bash
pip install jax2onnx  
```

Or get the latest development version from TestPyPI:

```bash
pip install -i https://test.pypi.org/simple/ jax2onnx
```

---

## 📜 License

This project is licensed under the Apache License, Version 2.0. See [`LICENSE`](./LICENSE) for details.

---

## 🌟 Special Thanks

Special thanks to the community members involved in:

- [Flax Feature Request #4430](https://github.com/google/flax/issues/4430)
- [JAX Feature Request #26430](https://github.com/jax-ml/jax/issues/26430)

A huge thanks especially to [@limarta](https://github.com/limarta), whose elegant [jaxpr-to-ONNX demonstration](https://gist.github.com/limarta/855a88cc1c0163487a9dc369891147ab) significantly inspired this project.

---

**Happy converting! 🎉**



