Metadata-Version: 2.3
Name: jax2onnx
Version: 0.3.2
Summary: export JAX to ONNX - focus on flax nnx
Author: enpasos
Author-email: matthias.unverzagt@enpasos.ai
Requires-Python: >=3.10
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Programming Language :: Python :: 3.13
Requires-Dist: flax (>=0.10.4)
Requires-Dist: jax (>=0.5.2)
Requires-Dist: ml_dtypes (==0.5.1)
Requires-Dist: netron (>=8.1.9)
Requires-Dist: onnx (>=1.17.0)
Requires-Dist: onnxruntime (>=1.21.0)
Requires-Dist: optax (==0.2.4)
Requires-Dist: orbax-checkpoint (==0.11.6)
Requires-Dist: orbax-export (==0.0.6)
Description-Content-Type: text/markdown

# jax2onnx 🌟

`jax2onnx` converts your JAX/Flax functions directly into the ONNX format.

![jax2onnx](https://enpasos.github.io/jax2onnx/images/jax2onnx.png)

**Key features:**

- **JAXPR-Based Conversion:** Uses JAX's built-in `jaxpr` as the foundation for conversion.
- **Flexible Plugin System:** Easy-to-write Python plugins handle specific JAX primitives or examples. Plugin registration and testing are automated.
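The snippet below gives a feel for the input a JAXPR-based converter works from. It uses plain JAX only (no `jax2onnx` API) to print the jaxpr of a small function and list the primitive behind each equation. The function `f` is an arbitrary illustration, and the exact jaxpr text may differ between JAX versions.

```python
import jax
import jax.numpy as jnp


def f(x, w):
    # Arbitrary example function: matmul, add a constant, apply tanh.
    return jnp.tanh(x @ w + 1.0)


x = jnp.ones((2, 3), dtype=jnp.float32)
w = jnp.ones((3, 4), dtype=jnp.float32)

# make_jaxpr traces f with abstract values of these shapes and dtypes.
closed_jaxpr = jax.make_jaxpr(f)(x, w)
print(closed_jaxpr)  # human-readable jaxpr

# Each equation names one primitive (e.g. dot_general, add, tanh); these are
# the units that get mapped onto ONNX operators.
primitive_names = [eqn.primitive.name for eqn in closed_jaxpr.jaxpr.eqns]
print(primitive_names)
```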



---

## 🚀 Quickstart

Here's how simple it is to convert your JAX callable to ONNX:

```python
from jax2onnx import save_onnx
from flax import nnx 

# Example: A minimal MLP (from Flax documentation)
class MLP(nnx.Module):
  def __init__(self, din, dmid, dout, *, rngs):
    self.linear1 = nnx.Linear(din, dmid, rngs=rngs)
    self.dropout = nnx.Dropout(rate=0.1, rngs=rngs)
    self.bn = nnx.BatchNorm(dmid, rngs=rngs)
    self.linear2 = nnx.Linear(dmid, dout, rngs=rngs)
  def __call__(self, x):
    x = nnx.gelu(self.dropout(self.bn(self.linear1(x))))
    return self.linear2(x)

# Instantiate model
my_callable = MLP(din=30, dmid=20, dout=10, rngs=nnx.Rngs(0))

# Convert and save to ONNX
save_onnx(
    my_callable,
    [('B', 30)],          # Input shapes, batch size 'B' is symbolic
    "my_callable.onnx"    # Output path
)
```
The result: [`my_callable.onnx`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/my_callable.onnx)


---

## 📅 Roadmap and Releases


### **Planned Version**
- **Ongoing**: Expanding test coverage.
- **0.4.0** *(Upcoming)*: Introducing ONNX function support.


### **Current Production Version**
- **0.3.0** *(PyPI Release)*: Streamlined the plugin system with automatic registration and simplified integration of custom primitives.

### **Past Versions**
- **0.2.0** *(First PyPI Release)*: Rebased the implementation on `jaxpr`, improving usability and adding low-level `lax` components.
- **0.1.0** *(Initial Approach, Not Released to PyPI)*: Produced ONNX exports for some `nnx` components and `nnx`-based examples, including a Vision Transformer (ViT).

---

## ❓ Troubleshooting

If conversion doesn't work out of the box, it could be due to:

- **Non-dynamic function references:**  
  JAXPR-based conversion requires function references to be resolved dynamically at call time, rather than bound once when the code is defined.  
  **Solution:** Wrap your function call inside a lambda to enforce dynamic resolution:
  ```python
  my_dynamic_callable_function = lambda x: original_function(x)
  ```

- **Unsupported primitives:**  
  The callable may use a primitive that `jax2onnx` does not support yet, or supports only partially.  
  **Solution:** Write a [plugin](#how-to-contribute) to handle the unsupported primitive (this is straightforward!).
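Before writing a plugin, it can help to see which primitives a callable actually emits. The following is a minimal sketch, assuming a hand-copied subset of the `lax` names from the table below; `KNOWN_PRIMITIVES` and `primitives_used` are hypothetical helpers, not part of `jax2onnx`. It only inspects the top-level jaxpr, so primitives nested inside calls such as `jax.jit` would need a recursive walk.

```python
import jax
import jax.numpy as jnp

# Hypothetical, illustrative subset of the lax primitive names listed in the
# table below. This is NOT the converter's own plugin registry.
KNOWN_PRIMITIVES = {
    "add", "sub", "mul", "div", "dot_general", "tanh",
    "exp", "log", "max", "min", "reshape", "transpose",
}


def primitives_used(fn, *example_args):
    """Return the primitive names in the top-level jaxpr of fn."""
    jaxpr = jax.make_jaxpr(fn)(*example_args).jaxpr
    return {eqn.primitive.name for eqn in jaxpr.eqns}


def f(x):
    # jnp.sin lowers to the `sin` primitive, which is not in the subset above.
    return jnp.sin(x) * 2.0 + jnp.tanh(x)


missing = sorted(primitives_used(f, jnp.ones((4,), jnp.float32)) - KNOWN_PRIMITIVES)
print(missing)  # ['sin']
```

Any name that shows up in `missing` is a candidate for a new plugin.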

---

## 🧩 Supported JAX/ONNX Components


<!-- AUTOGENERATED TABLE START -->

| JAX Component | ONNX Components | Testcases | Since |
|:-------------|:---------------|:---------|:------|
| [jnp.add](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.add.html) | [Add](https://onnx.ai/onnx/operators/onnx__Add.html) | [`add`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/jnp/add.onnx) ✅ | v0.1.0 |
| [jnp.concat](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.concat.html) | [Concat](https://onnx.ai/onnx/operators/onnx__Concat.html) | [`concat`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/jnp/concat.onnx) ✅ | v0.1.0 |
| [jnp.einsum](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.einsum.html) | [Einsum](https://onnx.ai/onnx/operators/onnx__Einsum.html) | [`einsum`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/jnp/einsum.onnx) ✅<br>[`einsum_matmul`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/jnp/einsum_matmul.onnx) ✅<br>[`einsum_dynamic_dynamic`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/jnp/einsum_dynamic_dynamic.onnx) ✅<br>[`einsum_dynamic`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/jnp/einsum_dynamic.onnx) ✅<br>[`einsum_dynamic_matmul_dynamic`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/jnp/einsum_dynamic_matmul_dynamic.onnx) ✅<br>[`einsum_dynamic_matmul`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/jnp/einsum_dynamic_matmul.onnx) ✅<br>[`einsum_transpose`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/jnp/einsum_transpose.onnx) ✅<br>[`einsum_dynamic_transpose_dynamic`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/jnp/einsum_dynamic_transpose_dynamic.onnx) ✅<br>[`einsum_dynamic_transpose`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/jnp/einsum_dynamic_transpose.onnx) ✅<br>[`einsum_dynamic_matmul2_dynamic`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/jnp/einsum_dynamic_matmul2_dynamic.onnx) ✅<br>[`einsum_dynamic_matmul2`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/jnp/einsum_dynamic_matmul2.onnx) ✅<br>[`einsum_dynamic_matmul3_dynamic`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/jnp/einsum_dynamic_matmul3_dynamic.onnx) ✅<br>[`einsum_dynamic_matmul3`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/jnp/einsum_dynamic_matmul3.onnx) ✅ | v0.1.0 |
| [jnp.matmul](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.matmul.html) | [MatMul](https://onnx.ai/onnx/operators/onnx__MatMul.html) | [`matmul_2d`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/jnp/matmul_2d.onnx) ✅<br>[`matmul_1d_2d`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/jnp/matmul_1d_2d.onnx) ✅<br>[`matmul_2d_1d`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/jnp/matmul_2d_1d.onnx) ✅<br>[`matmul_dynamic_dynamic`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/jnp/matmul_dynamic_dynamic.onnx) ✅<br>[`matmul_dynamic`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/jnp/matmul_dynamic.onnx) ✅<br>[`matmul_dynamic_a_dynamic`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/jnp/matmul_dynamic_a_dynamic.onnx) ✅<br>[`matmul_dynamic_a`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/jnp/matmul_dynamic_a.onnx) ✅<br>[`matmul_dynamic_b_dynamic`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/jnp/matmul_dynamic_b_dynamic.onnx) ✅<br>[`matmul_dynamic_b`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/jnp/matmul_dynamic_b.onnx) ✅<br>[`matmul_1d`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/jnp/matmul_1d.onnx) ✅<br>[`matmul_3d`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/jnp/matmul_3d.onnx) ✅ | v0.1.0 |
| [jnp.reshape](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.reshape.html) | [Reshape](https://onnx.ai/onnx/operators/onnx__Reshape.html) | [`reshape_1`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/jnp/reshape_1.onnx) ✅<br>[`reshape_2`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/jnp/reshape_2.onnx) ✅<br>[`reshape_3`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/jnp/reshape_3.onnx) ✅<br>[`reshape_4_dynamic`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/jnp/reshape_4_dynamic.onnx) ✅<br>[`reshape_4`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/jnp/reshape_4.onnx) ✅<br>[`reshape_to_scalar`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/jnp/reshape_to_scalar.onnx) ✅<br>[`reshape_from_scalar`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/jnp/reshape_from_scalar.onnx) ✅ | v0.1.0 |
| [jnp.squeeze](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.squeeze.html) | [Squeeze](https://onnx.ai/onnx/operators/onnx__Squeeze.html) | [`squeeze_single_dim`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/jnp/squeeze_single_dim.onnx) ✅<br>[`squeeze_multiple_dims`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/jnp/squeeze_multiple_dims.onnx) ✅<br>[`squeeze_vit_output`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/jnp/squeeze_vit_output.onnx) ✅<br>[`squeeze_dynamic_batch_dynamic`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/jnp/squeeze_dynamic_batch_dynamic.onnx) ✅<br>[`squeeze_dynamic_batch`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/jnp/squeeze_dynamic_batch.onnx) ✅<br>[`squeeze_all_dims`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/jnp/squeeze_all_dims.onnx) ✅<br>[`squeeze_negative_axis`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/jnp/squeeze_negative_axis.onnx) ✅<br>[`squeeze_negative_axis_tuple`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/jnp/squeeze_negative_axis_tuple.onnx) ✅<br>[`squeeze_dynamic_and_negative_axis_dynamic`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/jnp/squeeze_dynamic_and_negative_axis_dynamic.onnx) ✅<br>[`squeeze_dynamic_and_negative_axis`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/jnp/squeeze_dynamic_and_negative_axis.onnx) ✅ | v0.1.0 |
| [jnp.tile](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.tile.html) | [Tile](https://onnx.ai/onnx/operators/onnx__Tile.html) | [`tile_a`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/jnp/tile_a.onnx) ✅<br>[`tile_b`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/jnp/tile_b.onnx) ✅<br>[`tile_c`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/jnp/tile_c.onnx) ✅<br>[`tile_d`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/jnp/tile_d.onnx) ✅<br>[`tile_dynamic_dynamic`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/jnp/tile_dynamic_dynamic.onnx) ✅<br>[`tile_dynamic`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/jnp/tile_dynamic.onnx) ✅<br>[`tile_pad`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/jnp/tile_pad.onnx) ✅ | v0.1.0 |
| [jnp.transpose](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.transpose.html) | [Transpose](https://onnx.ai/onnx/operators/onnx__Transpose.html) | [`transpose_basic`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/jnp/transpose_basic.onnx) ✅<br>[`transpose_reverse`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/jnp/transpose_reverse.onnx) ✅<br>[`transpose_4d`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/jnp/transpose_4d.onnx) ✅<br>[`transpose_square_matrix`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/jnp/transpose_square_matrix.onnx) ✅<br>[`transpose_high_dim`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/jnp/transpose_high_dim.onnx) ✅<br>[`transpose_no_axes`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/jnp/transpose_no_axes.onnx) ✅<br>[`transpose_dynamic_dynamic`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/jnp/transpose_dynamic_dynamic.onnx) ✅<br>[`transpose_dynamic`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/jnp/transpose_dynamic.onnx) ✅ | v0.1.0 |
| [add](https://docs.jax.dev/en/latest/_autosummary/jax.lax.add.html) | [Add](https://onnx.ai/onnx/operators/onnx__Add.html) | [`add`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/lax/add.onnx) ✅ | v0.2.0 |
| [argmax](https://docs.jax.dev/en/latest/_autosummary/jax.lax.argmax.html) | [ArgMax](https://onnx.ai/onnx/operators/onnx__ArgMax.html) | [`argmax_test1`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/lax/argmax_test1.onnx) ✅<br>[`argmax_test2`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/lax/argmax_test2.onnx) ✅ | v0.2.0 |
| [argmin](https://docs.jax.dev/en/latest/_autosummary/jax.lax.argmin.html) | [ArgMin](https://onnx.ai/onnx/operators/onnx__ArgMin.html) | [`argmin_test1`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/lax/argmin_test1.onnx) ✅<br>[`argmin_test2`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/lax/argmin_test2.onnx) ✅ | v0.2.0 |
| [broadcast_in_dim](https://docs.jax.dev/en/latest/_autosummary/jax.lax.broadcast_in_dim.html) | [Expand](https://onnx.ai/onnx/operators/onnx__Expand.html) | [`broadcast_in_dim`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/lax/broadcast_in_dim.onnx) ✅<br>[`broadcast_in_dim_2d_to_3d`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/lax/broadcast_in_dim_2d_to_3d.onnx) ✅<br>[`broadcast_in_dim_scalar`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/lax/broadcast_in_dim_scalar.onnx) ✅ | v0.2.0 |
| [concatenate](https://docs.jax.dev/en/latest/_autosummary/jax.lax.concatenate.html) | [Concat](https://onnx.ai/onnx/operators/onnx__Concat.html) | [`concatenate`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/lax/concatenate.onnx) ✅<br>[`concatenate_axis1`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/lax/concatenate_axis1.onnx) ✅<br>[`concatenate_dynamic_dynamic`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/lax/concatenate_dynamic_dynamic.onnx) ✅<br>[`concatenate_dynamic`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/lax/concatenate_dynamic.onnx) ✅<br>[`concatenate_3d`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/lax/concatenate_3d.onnx) ✅ | v0.2.0 |
| [conv_general_dilated](https://docs.jax.dev/en/latest/_autosummary/jax.lax.conv_general_dilated.html) | [Conv](https://onnx.ai/onnx/operators/onnx__Conv.html) | [`conv`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/lax/conv.onnx) ✅<br>[`conv2`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/lax/conv2.onnx) ✅ | v0.2.0 |
| [convert_element_type](https://docs.jax.dev/en/latest/_autosummary/jax.lax.convert_element_type.html) | [Cast](https://onnx.ai/onnx/operators/onnx__Cast.html) | [`convert_element_type`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/lax/convert_element_type.onnx) ✅ | v0.2.0 |
| [div](https://docs.jax.dev/en/latest/_autosummary/jax.lax.div.html) | [Div](https://onnx.ai/onnx/operators/onnx__Div.html) | [`div`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/lax/div.onnx) ✅ | v0.2.0 |
| [dot_general](https://docs.jax.dev/en/latest/_autosummary/jax.lax.dot_general.html) | [MatMul](https://onnx.ai/onnx/operators/onnx__MatMul.html) | [`dot_general`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/lax/dot_general.onnx) ✅ | v0.2.0 |
| [dynamic_slice](https://docs.jax.dev/en/latest/_autosummary/jax.lax.dynamic_slice.html) | [Slice](https://onnx.ai/onnx/operators/onnx__Slice.html) | [`dynamic_slice_test1`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/lax/dynamic_slice_test1.onnx) ✅<br>[`dynamic_slice_2d`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/lax/dynamic_slice_2d.onnx) ✅<br>[`dynamic_slice_3d`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/lax/dynamic_slice_3d.onnx) ✅ | v0.1.0 |
| [eq](https://docs.jax.dev/en/latest/_autosummary/jax.lax.eq.html) | [Equal](https://onnx.ai/onnx/operators/onnx__Equal.html) | [`eq`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/lax/eq.onnx) ✅ | v0.2.0 |
| [exp](https://docs.jax.dev/en/latest/_autosummary/jax.lax.exp.html) | [Exp](https://onnx.ai/onnx/operators/onnx__Exp.html) | [`exp`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/lax/exp.onnx) ✅ | v0.2.0 |
| [gather](https://docs.jax.dev/en/latest/_autosummary/jax.lax.gather.html) | [Gather](https://onnx.ai/onnx/operators/onnx__Gather.html) | [`gather`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/lax/gather.onnx) ✅ | v0.2.0 |
| [gt](https://docs.jax.dev/en/latest/_autosummary/jax.lax.gt.html) | [Greater](https://onnx.ai/onnx/operators/onnx__Greater.html) | [`gt`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/lax/gt.onnx) ✅ | v0.2.0 |
| [integer_pow](https://docs.jax.dev/en/latest/_autosummary/jax.lax.integer_pow.html) | [Pow](https://onnx.ai/onnx/operators/onnx__Pow.html) | [`integer_pow`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/lax/integer_pow.onnx) ✅ | v0.2.0 |
| [log](https://docs.jax.dev/en/latest/_autosummary/jax.lax.log.html) | [Log](https://onnx.ai/onnx/operators/onnx__Log.html) | [`log`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/lax/log.onnx) ✅ | v0.2.0 |
| [lt](https://docs.jax.dev/en/latest/_autosummary/jax.lax.lt.html) | [Less](https://onnx.ai/onnx/operators/onnx__Less.html) | [`lt`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/lax/lt.onnx) ✅ | v0.2.0 |
| [max](https://docs.jax.dev/en/latest/_autosummary/jax.lax.max.html) | [Max](https://onnx.ai/onnx/operators/onnx__Max.html) | [`max`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/lax/max.onnx) ✅ | v0.2.0 |
| [min](https://docs.jax.dev/en/latest/_autosummary/jax.lax.min.html) | [Min](https://onnx.ai/onnx/operators/onnx__Min.html) | [`min_test1`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/lax/min_test1.onnx) ✅ | v0.1.0 |
| [mul](https://docs.jax.dev/en/latest/_autosummary/jax.lax.mul.html) | [Mul](https://onnx.ai/onnx/operators/onnx__Mul.html) | [`mul_test1`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/lax/mul_test1.onnx) ✅<br>[`mul_test2`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/lax/mul_test2.onnx) ✅ | v0.1.0 |
| [ne](https://docs.jax.dev/en/latest/_autosummary/jax.lax.ne.html) | [Equal](https://onnx.ai/onnx/operators/onnx__Equal.html)<br>[Not](https://onnx.ai/onnx/operators/onnx__Not.html) | [`ne`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/lax/ne.onnx) ✅ | v0.2.0 |
| [neg](https://docs.jax.dev/en/latest/_autosummary/jax.lax.neg.html) | [Neg](https://onnx.ai/onnx/operators/onnx__Neg.html) | [`neg`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/lax/neg.onnx) ✅ | v0.2.0 |
| [reduce_max](https://docs.jax.dev/en/latest/_autosummary/jax.lax.reduce_max.html) | [ReduceMax](https://onnx.ai/onnx/operators/onnx__ReduceMax.html) | [`reduce_max`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/lax/reduce_max.onnx) ✅ | v0.2.0 |
| [reduce_min](https://docs.jax.dev/en/latest/_autosummary/jax.lax.reduce_min.html) | [ReduceMin](https://onnx.ai/onnx/operators/onnx__ReduceMin.html) | [`reduce_min`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/lax/reduce_min.onnx) ✅ | v0.2.0 |
| [reduce_sum](https://docs.jax.dev/en/latest/_autosummary/jax.lax.reduce_sum.html) | [ReduceSum](https://onnx.ai/onnx/operators/onnx__ReduceSum.html) | [`reduce_sum`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/lax/reduce_sum.onnx) ✅ | v0.2.0 |
| [reshape](https://docs.jax.dev/en/latest/_autosummary/jax.lax.reshape.html) | [Reshape](https://onnx.ai/onnx/operators/onnx__Reshape.html) | [`reshape`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/lax/reshape.onnx) ✅ | v0.2.0 |
| [slice](https://docs.jax.dev/en/latest/_autosummary/jax.lax.slice.html) | [Slice](https://onnx.ai/onnx/operators/onnx__Slice.html) | [`slice_test1`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/lax/slice_test1.onnx) ✅ | v0.1.0 |
| [sort](https://docs.jax.dev/en/latest/_autosummary/jax.lax.sort.html) | [TopK](https://onnx.ai/onnx/operators/onnx__TopK.html) | [`sort_1d`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/lax/sort_1d.onnx) ✅<br>[`sort_1d_empty`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/lax/sort_1d_empty.onnx) ✅<br>[`sort_1d_single`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/lax/sort_1d_single.onnx) ✅<br>[`sort_1d_larger`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/lax/sort_1d_larger.onnx) ✅<br>[`sort_1d_specific_values`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/lax/sort_1d_specific_values.onnx) ✅ | v0.2.0 |
| [sqrt](https://docs.jax.dev/en/latest/_autosummary/jax.lax.sqrt.html) | [Sqrt](https://onnx.ai/onnx/operators/onnx__Sqrt.html) | [`sqrt`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/lax/sqrt.onnx) ✅ | v0.2.0 |
| [square](https://docs.jax.dev/en/latest/_autosummary/jax.lax.square.html) | [Mul](https://onnx.ai/onnx/operators/onnx__Mul.html) | [`square`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/lax/square.onnx) ✅ | v0.2.0 |
| [squeeze](https://docs.jax.dev/en/latest/_autosummary/jax.lax.squeeze.html) | [Squeeze](https://onnx.ai/onnx/operators/onnx__Squeeze.html) | [`squeeze`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/lax/squeeze.onnx) ✅ | v0.2.0 |
| [stop_gradient](https://docs.jax.dev/en/latest/_autosummary/jax.lax.stop_gradient.html) | [Identity](https://onnx.ai/onnx/operators/onnx__Identity.html) | [`stop_gradient`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/lax/stop_gradient.onnx) ✅ | v0.2.0 |
| [sub](https://docs.jax.dev/en/latest/_autosummary/jax.lax.sub.html) | [Sub](https://onnx.ai/onnx/operators/onnx__Sub.html) | [`sub_test1`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/lax/sub_test1.onnx) ✅<br>[`sub_test2`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/lax/sub_test2.onnx) ✅ | v0.1.0 |
| [tanh](https://docs.jax.dev/en/latest/_autosummary/jax.lax.tanh.html) | [Tanh](https://onnx.ai/onnx/operators/onnx__Tanh.html) | [`tanh`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/lax/tanh.onnx) ✅ | v0.2.0 |
| [transpose](https://docs.jax.dev/en/latest/_autosummary/jax.lax.transpose.html) | [Transpose](https://onnx.ai/onnx/operators/onnx__Transpose.html) | [`transpose_basic`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/lax/transpose_basic.onnx) ✅ | v0.2.0 |
| [nn.softmax](https://jax.readthedocs.io/en/latest/_autosummary/jax.nn.softmax.html) | [Softmax](https://onnx.ai/onnx/operators/onnx__Softmax.html) | [`softmax`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/nn/softmax.onnx) ✅<br>[`softmax_2d`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/nn/softmax_2d.onnx) ✅<br>[`softmax_3d`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/nn/softmax_3d.onnx) ✅ | v0.1.0 |
| [nnx.avg_pool](https://flax-linen.readthedocs.io/en/latest/api_reference/flax.linen/layers.html#flax.linen.avg_pool) | [AveragePool](https://onnx.ai/onnx/operators/onnx__AveragePool.html)<br>[Transpose](https://onnx.ai/onnx/operators/onnx__Transpose.html) | [`avg_pool`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/nnx/avg_pool.onnx) ✅<br>[`avg_pool_same_padding`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/nnx/avg_pool_same_padding.onnx) ✅<br>[`avg_pool_default_padding`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/nnx/avg_pool_default_padding.onnx) ✅ | v0.1.0 |
| [nnx.batch_norm](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/normalization.html#flax.nnx.BatchNorm) | [BatchNormalization](https://onnx.ai/onnx/operators/onnx__BatchNormalization.html) | [`batch_norm`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/nnx/batch_norm.onnx) ✅<br>[`batch_norm_2`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/nnx/batch_norm_2.onnx) ✅<br>[`batch_norm_3d`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/nnx/batch_norm_3d.onnx) ✅<br>[`batch_norm_float64`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/nnx/batch_norm_float64.onnx) ✅<br>[`batch_norm_single_batch`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/nnx/batch_norm_single_batch.onnx) ✅ | v0.1.0 |
| [nnx.conv](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/linear.html#flax.nnx.Conv) | [Conv](https://onnx.ai/onnx/operators/onnx__Conv.html)<br>[Transpose](https://onnx.ai/onnx/operators/onnx__Transpose.html) | [`conv_dynamic`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/nnx/conv_dynamic.onnx) ✅<br>[`conv`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/nnx/conv.onnx) ✅<br>[`conv_2`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/nnx/conv_2.onnx) ✅<br>[`conv_3`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/nnx/conv_3.onnx) ✅<br>[`conv_4`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/nnx/conv_4.onnx) ✅ | v0.1.0 |
| [nnx.conv_transpose](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/conv_transpose.html) | [ConvTranspose](https://onnx.ai/onnx/operators/onnx__ConvTranspose.html) | [`conv_transpose_valid_padding`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/nnx/conv_transpose_valid_padding.onnx) ✅<br>[`conv_transpose_circular_padding`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/nnx/conv_transpose_circular_padding.onnx) ✅ | v0.3.0 |
| [nnx.dot_product_attention](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/attention.html#flax.nnx.dot_product_attention) | [Constant](https://onnx.ai/onnx/operators/onnx__Constant.html)<br>[Einsum](https://onnx.ai/onnx/operators/onnx__Einsum.html)<br>[Mul](https://onnx.ai/onnx/operators/onnx__Mul.html)<br>[Softmax](https://onnx.ai/onnx/operators/onnx__Softmax.html) | [`dot_product_attention`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/nnx/dot_product_attention.onnx) ✅ | v0.1.0 |
| [nnx.dropout](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/stochastic.html#flax.nnx.Dropout) | [Dropout](https://onnx.ai/onnx/operators/onnx__Dropout.html) | [`dropout_inference`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/nnx/dropout_inference.onnx) ✅ | v0.1.0 |
| [nnx.elu](https://jax.readthedocs.io/en/latest/_autosummary/jax.nn.elu.html) | [Elu](https://onnx.ai/onnx/operators/onnx__Elu.html) | [`elu`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/nnx/elu.onnx) ✅ | v0.1.0 |
| [nnx.gelu](https://jax.readthedocs.io/en/latest/_autosummary/jax.nn.gelu.html) | [Gelu](https://onnx.ai/onnx/operators/onnx__Gelu.html) | [`gelu`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/nnx/gelu.onnx) ✅<br>[`gelu_1`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/nnx/gelu_1.onnx) ✅<br>[`gelu_2`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/nnx/gelu_2.onnx) ✅<br>[`gelu_3`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/nnx/gelu_3.onnx) ✅ | v0.1.0 |
| [nnx.group_norm](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/normalization.html#flax.nnx.GroupNorm) | [GroupNormalization](https://onnx.ai/onnx/operators/onnx__GroupNormalization.html) | [`group_norm`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/nnx/group_norm.onnx) ✅<br>[`group_norm_2`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/nnx/group_norm_2.onnx) ✅ | v0.3.0 |
| [nnx.layer_norm](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/normalization.html#flax.nnx.LayerNorm) | [LayerNormalization](https://onnx.ai/onnx/operators/onnx__LayerNormalization.html) | [`layer_norm`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/nnx/layer_norm.onnx) ✅<br>[`layer_norm_multiaxis`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/nnx/layer_norm_multiaxis.onnx) ✅ | v0.1.0 |
| [nnx.leaky_relu](https://jax.readthedocs.io/en/latest/_autosummary/jax.nn.leaky_relu.html) | [LeakyRelu](https://onnx.ai/onnx/operators/onnx__LeakyRelu.html) | [`leaky_relu`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/nnx/leaky_relu.onnx) ✅ | v0.1.0 |
| [nnx.linear](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/linear.html) | [Gemm](https://onnx.ai/onnx/operators/onnx__Gemm.html)<br>[Reshape](https://onnx.ai/onnx/operators/onnx__Reshape.html) | [`linear_2d`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/nnx/linear_2d.onnx) ✅<br>[`linear_dynamic`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/nnx/linear_dynamic.onnx) ✅<br>[`linear`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/nnx/linear.onnx) ✅ | v0.1.0 |
| [nnx.linear_general](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/linear.html#flax.nnx.LinearGeneral) | [Gemm](https://onnx.ai/onnx/operators/onnx__Gemm.html)<br>[Reshape](https://onnx.ai/onnx/operators/onnx__Reshape.html) | [`linear_general_dynamic`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/nnx/linear_general_dynamic.onnx) ✅<br>[`linear_general`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/nnx/linear_general.onnx) ✅<br>[`linear_general_2`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/nnx/linear_general_2.onnx) ✅<br>[`linear_general_3`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/nnx/linear_general_3.onnx) ✅<br>[`linear_general_4`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/nnx/linear_general_4.onnx) ✅ | v0.1.0 |
| [nnx.log_softmax](https://jax.readthedocs.io/en/latest/_autosummary/jax.nn.log_softmax.html) | [LogSoftmax](https://onnx.ai/onnx/operators/onnx__LogSoftmax.html) | [`log_softmax`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/nnx/log_softmax.onnx) ✅ | v0.1.0 |
| [nnx.max_pool](https://flax-linen.readthedocs.io/en/latest/api_reference/flax.linen/layers.html#flax.linen.max_pool) | [MaxPool](https://onnx.ai/onnx/operators/onnx__MaxPool.html)<br>[Transpose](https://onnx.ai/onnx/operators/onnx__Transpose.html) | [`max_pool`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/nnx/max_pool.onnx) ✅<br>[`max_pool_same_padding`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/nnx/max_pool_same_padding.onnx) ✅ | v0.1.0 |
| [nnx.relu](https://jax.readthedocs.io/en/latest/_autosummary/jax.nn.relu.html) | [Relu](https://onnx.ai/onnx/operators/onnx__Relu.html) | [`relu`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/nnx/relu.onnx) ✅<br>[`relu_2`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/nnx/relu_2.onnx) ✅ | v0.1.0 |
| [nnx.rms_norm](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/normalization.html#flax.nnx.RMSNorm) | [RMSNormalization](https://onnx.ai/onnx/operators/onnx__RMSNormalization.html) | [`rms_norm`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/nnx/rms_norm.onnx) ✅<br>[`rms_norm_2`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/nnx/rms_norm_2.onnx) ✅ | v0.3.0 |
| [nnx.sigmoid](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/activations.html#flax.nnx.sigmoid) | [Sigmoid](https://onnx.ai/onnx/operators/onnx__Sigmoid.html) | [`sigmoid`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/nnx/sigmoid.onnx) ✅ | v0.1.0 |
| [nnx.softmax](https://jax.readthedocs.io/en/latest/_autosummary/jax.nn.softmax.html) | [Softmax](https://onnx.ai/onnx/operators/onnx__Softmax.html) | [`softmax`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/nnx/softmax.onnx) ✅ | v0.1.0 |
| [nnx.softplus](https://jax.readthedocs.io/en/latest/_autosummary/jax.nn.softplus.html) | [Softplus](https://onnx.ai/onnx/operators/onnx__Softplus.html) | [`softplus`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/nnx/softplus.onnx) ✅ | v0.1.0 |
| [nnx.tanh](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/activations.html#flax.nnx.tanh) | [Tanh](https://onnx.ai/onnx/operators/onnx__Tanh.html) | [`tanh`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/primitives/nnx/tanh.onnx) ✅ | v0.1.0 |

<!-- AUTOGENERATED TABLE END -->

**Legend:**  
✅ = Passed  
❌ = Failed  
➖ = No testcase yet
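Some rows in the supported-components table list a helper op next to the main ONNX op, for example `Reshape` next to `Gemm` and `Transpose` next to `MaxPool`. Our assumption, not a description of the converter's internals, is that these helpers bridge two format differences. ONNX `Gemm` accepts only 2-D inputs. ONNX `MaxPool` expects channels-first (NCHW) tensors, while Flax pools channels-last (NHWC). The NumPy sketch below shows both adaptations and checks them against direct computations. The helper names are hypothetical.

```python
import numpy as np


def linear_via_gemm(x, w, b):
    """Hypothetical helper: apply a dense layer to an N-D input using only a
    2-D matrix multiply, mirroring a Reshape -> Gemm -> Reshape pattern."""
    lead_shape = x.shape[:-1]                      # e.g. (batch, tokens)
    x2d = x.reshape(-1, x.shape[-1])               # Reshape to (batch*tokens, din)
    y2d = x2d @ w + b                              # Gemm works on 2-D operands only
    return y2d.reshape(*lead_shape, w.shape[1])    # Reshape back to (..., dout)


def max_pool_nchw(x, k):
    """Non-overlapping k x k max pooling (stride k) on an NCHW tensor."""
    n, c, h, w = x.shape
    x = x[:, :, : h - h % k, : w - w % k]          # drop rows/cols that don't fill a window
    return x.reshape(n, c, h // k, k, w // k, k).max(axis=(3, 5))


def max_pool_nhwc_via_nchw(x, k):
    """Pool a channels-last tensor with a channels-first kernel,
    mirroring a Transpose -> MaxPool -> Transpose pattern."""
    y = max_pool_nchw(x.transpose(0, 3, 1, 2), k)  # NHWC -> NCHW
    return y.transpose(0, 2, 3, 1)                 # NCHW -> NHWC


# Usage: compare both adaptations against direct computations.
rng = np.random.default_rng(0)
x = rng.normal(size=(2, 5, 30))                    # (batch, tokens, din)
w = rng.normal(size=(30, 10))
b = rng.normal(size=(10,))
print(np.allclose(linear_via_gemm(x, w, b), x @ w + b))      # True

img = rng.normal(size=(2, 8, 8, 3))                # NHWC
direct = img.reshape(2, 4, 2, 4, 2, 3).max(axis=(2, 4))
print(np.allclose(max_pool_nhwc_via_nchw(img, 2), direct))   # True
```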

---

## 🎯 Examples

<!-- AUTOGENERATED EXAMPLES TABLE START -->

| Component | Description | Children | Testcases | Since |
|:----------|:------------|:---------|:---------|:------|
| AutoEncoder | A simple autoencoder example. | Encoder<br>Decoder | [`autoencoder`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/examples/nnx/autoencoder.onnx) ✅ | v0.2.0 |
| CNN | A simple convolutional neural network (CNN). | nnx.Conv<br>nnx.Linear<br>nnx.avg_pool<br>nnx.relu<br>lax.reshape | [`cnn`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/examples/nnx/cnn.onnx) ✅ | v0.1.0 |
| ConvEmbedding | Convolutional Token Embedding for MNIST with hierarchical downsampling. | flax.nnx.Conv<br>flax.nnx.LayerNorm<br>jax.numpy.reshape<br>jax.nn.relu | [`mnist_conv_embedding`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/examples/nnx/mnist_conv_embedding.onnx) ✅ | v0.1.0 |
| MLP | A simple Multi-Layer Perceptron (MLP) with BatchNorm, Dropout, and GELU activation. | nnx.Linear<br>nnx.Dropout<br>nnx.BatchNorm<br>nnx.gelu | [`mlp_dynamic`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/examples/nnx/mlp_dynamic.onnx) ✅<br>[`mlp`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/examples/nnx/mlp.onnx) ✅ | v0.1.0 |
| MLPBlock | The MLP block used inside a Transformer. | flax.nnx.Linear<br>flax.nnx.Dropout<br>flax.nnx.gelu | [`mlp_block`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/examples/nnx/mlp_block.onnx) ✅ | v0.1.0 |
| MultiHeadAttention | A multi-head attention module implemented in Flax/nnx that has no ONNX counterpart at the same granularity. | nnx.LinearGeneral<br>nnx.dot_product_attention | [`multihead_attention`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/examples/nnx/multihead_attention.onnx) ✅ | v0.2.0 |
| PatchEmbedding | Splits the image into patches and embeds each patch linearly. | flax.nnx.Linear<br>jax.numpy.transpose<br>jax.numpy.reshape | [`patch_embedding`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/examples/nnx/patch_embedding.onnx) ✅ | v0.1.0 |
| TransformerBlock | A Transformer block as described in 'Attention Is All You Need'. | flax.nnx.MultiHeadAttention<br>flax.nnx.LayerNorm<br>MLPBlock<br>flax.nnx.Dropout | [`transformer_block`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/examples/nnx/transformer_block.onnx) ✅ | v0.1.0 |
| VisionTransformer | A Vision Transformer (ViT) model for MNIST with configurable embedding type. | PatchEmbedding<br>ConvEmbedding<br>MLPBlock<br>TransformerBlock<br>nnx.MultiHeadAttention<br>nnx.LayerNorm<br>nnx.Linear<br>nnx.gelu<br>nnx.Dropout<br>nnx.Param | [`vit_conv_embedding`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/examples/nnx/vit_conv_embedding.onnx) ✅<br>[`vit_patch_embedding`](https://netron.app/?url=https://enpasos.github.io/jax2onnx/onnx/examples/nnx/vit_patch_embedding.onnx) ✅ | v0.2.0 |

<!-- AUTOGENERATED EXAMPLES TABLE END -->
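The MultiHeadAttention example has no single ONNX op at the same granularity. It therefore has to be expressed through smaller ops. The core computation is scaled dot-product attention, softmax(QKᵀ/√d)·V. This maps onto elementary operations: two matrix multiplies, a scaling, and a Softmax.

The NumPy sketch below is illustrative only and is not the graph jax2onnx actually emits. `dot_product_attention` here is our own helper, not `nnx.dot_product_attention`. The tensor layout follows the (batch, length, heads, head_dim) convention of Flax's attention functions.

```python
import numpy as np


def softmax(x, axis=-1):
    # Subtract the max for numerical stability before exponentiating.
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def dot_product_attention(q, k, v):
    """Scaled dot-product attention (illustrative helper).

    q:    (batch, q_len, heads, head_dim)
    k, v: (batch, kv_len, heads, head_dim)
    returns (batch, q_len, heads, head_dim)
    """
    depth = q.shape[-1]
    # MatMul + scaling -> attention logits of shape (batch, heads, q_len, kv_len)
    logits = np.einsum("bqhd,bkhd->bhqk", q, k) / np.sqrt(depth)
    # Softmax over the key positions
    weights = softmax(logits, axis=-1)
    # MatMul with the values, back to (batch, q_len, heads, head_dim)
    return np.einsum("bhqk,bkhd->bqhd", weights, v)


# Usage
rng = np.random.default_rng(0)
q = rng.normal(size=(2, 5, 4, 8))
k = rng.normal(size=(2, 7, 4, 8))
v = rng.normal(size=(2, 7, 4, 8))
print(dot_product_attention(q, k, v).shape)  # (2, 5, 4, 8)
```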

---

## 📌 Dependencies

**Versions of Major Dependencies:**

| Library       | Version |
|:--------------|:--------|
| `JAX`         | 0.5.3   |
| `Flax`        | 0.10.4  |
| `onnx`        | 1.17.0  |
| `onnxruntime` | 1.21.0  |

*Note: For more details, check `pyproject.toml`.*
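To compare a local environment against the versions in the table, you can read the installed versions with the standard library's `importlib.metadata`. The sketch below is a small helper of our own; the function name `installed_versions` is hypothetical. It looks packages up by their PyPI distribution names (`jax`, `flax`, `onnx`, `onnxruntime`).

```python
from importlib.metadata import PackageNotFoundError, version

# PyPI distribution names of the major dependencies listed in the table.
MAJOR_DEPENDENCIES = ("jax", "flax", "onnx", "onnxruntime")


def installed_versions(packages=MAJOR_DEPENDENCIES):
    """Return {package: installed version string, or None if not installed}."""
    result = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name:12} {ver or 'not installed'}")
```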

---

## ⚠️ Limitations

- Not all JAX/Flax components are supported yet; contributions that extend coverage are welcome!
- Function references must be resolved dynamically at call time.
- The ONNX graph is assembled entirely in memory before it is saved to disk, which can cause memory issues with very large models.
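The call-time resolution limitation is easiest to see in plain Python. Suppose the converter intercepts a function by temporarily replacing it as an attribute of its module during tracing. That is our reading of the limitation, not a documented guarantee. In that case, only code that looks the function up on the module at call time sees the replacement. A name bound earlier, as with `from flax.nnx import relu`, keeps pointing at the original.

The module and functions below are hypothetical stand-ins, not real Flax or jax2onnx objects.

```python
import types

# A hypothetical stand-in for a library module such as `flax.nnx`.
fake_nnx = types.ModuleType("fake_nnx")
fake_nnx.relu = lambda x: max(x, 0)

# Bound once at import time, like `from flax.nnx import relu`.
relu_bound_early = fake_nnx.relu


def call_via_module(x):
    return fake_nnx.relu(x)        # looked up on the module at call time


def call_via_early_binding(x):
    return relu_bound_early(x)     # fixed when the name was bound


# Simulate a converter temporarily patching the module attribute.
original = fake_nnx.relu
fake_nnx.relu = lambda x: "intercepted"
try:
    print(call_via_module(-1))         # intercepted
    print(call_via_early_binding(-1))  # 0 (the patch is bypassed)
finally:
    fake_nnx.relu = original           # restore the original function
```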

---

## 🤝 How to Contribute

We warmly welcome contributions!

**How you can help:**

- **Add a plugin:** Extend `jax2onnx` by writing a simple Python file in [`jax2onnx/converter/plugins`](./jax2onnx/converter/plugins).
- **Provide examples:** Add an illustrative example to the [`examples`](./jax2onnx/examples) folder.
- **Bug fixes & improvements:** PRs and issues are always welcome.

---

## 💾 Installation

Install from PyPI:

```bash
pip install jax2onnx
```

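Or get the latest development version from TestPyPI. Dependencies such as `jax` and `flax` are not necessarily published there, so let pip fall back to the regular PyPI index for them:

```bash
pip install -i https://test.pypi.org/simple/ --extra-index-url https://pypi.org/simple/ jax2onnx
```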

---

## 📜 License

This project is licensed under the Apache License, Version 2.0. See [`LICENSE`](./LICENSE) for details.

---

## 🌟 Special Thanks

Special thanks to the community members involved in:

- [Flax Feature Request #4430](https://github.com/google/flax/issues/4430)
- [JAX Feature Request #26430](https://github.com/jax-ml/jax/issues/26430)

A huge thanks especially to [@limarta](https://github.com/limarta), whose elegant [jaxpr-to-ONNX demonstration](https://gist.github.com/limarta/855a88cc1c0163487a9dc369891147ab) significantly inspired this project.

---

**Happy converting! 🎉**



