Metadata-Version: 2.4
Name: inox
Version: 0.7.2
Summary: Stainless neural networks in JAX
Author-email: François Rozet <francois.rozet@outlook.com>
License-Expression: MIT
Project-URL: documentation, https://inox.readthedocs.io
Project-URL: source, https://github.com/francois-rozet/inox
Project-URL: tracker, https://github.com/francois-rozet/inox/issues
Keywords: jax,pytree,neural networks,deep learning
Classifier: Intended Audience :: Developers
Classifier: Intended Audience :: Science/Research
Classifier: Natural Language :: English
Classifier: Operating System :: OS Independent
Classifier: Programming Language :: Python :: 3
Requires-Python: >=3.9
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: einops>=0.5.0
Requires-Dist: jax>=0.4.26
Provides-Extra: docs
Requires-Dist: docutils==0.19; extra == "docs"
Requires-Dist: furo==2024.5.6; extra == "docs"
Requires-Dist: myst-nb==1.0.0; extra == "docs"
Requires-Dist: sphinx==7.3.7; extra == "docs"
Provides-Extra: lint
Requires-Dist: pre-commit>=3.7.0; extra == "lint"
Requires-Dist: ruff==0.9.9; extra == "lint"
Provides-Extra: test
Requires-Dist: pytest>=8.0.0; extra == "test"
Dynamic: license-file

![Inox's banner](https://raw.githubusercontent.com/francois-rozet/inox/master/docs/images/banner.svg)

# Stainless neural networks in JAX

Inox is a minimal [JAX](https://github.com/google/jax) library for neural networks with an intuitive [PyTorch](https://github.com/pytorch/pytorch)-like syntax. As with [Equinox](https://github.com/patrick-kidger/equinox), modules are represented as PyTrees, which enables complex architectures, easy manipulations, and functional transformations.

Inox aims to be a leaner version of Equinox by only retaining its core features: PyTrees and lifted transformations. In addition, Inox takes inspiration from other projects like [NNX](https://github.com/cgarciae/nnx) and [Serket](https://github.com/ASEM000/serket) to provide a versatile interface. Despite the differences, Inox remains compatible with the Equinox ecosystem, and its components (modules, transformations, ...) are for the most part interchangeable with those of Equinox.
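To illustrate what "modules are represented as PyTrees" buys you, here is a minimal sketch that uses only JAX's `jax.tree_util.register_pytree_node_class`. The `Linear` class is hypothetical and is not how Inox defines or registers its modules. It only shows why an object whose arrays are exposed as PyTree leaves can be passed directly to `jax.grad`.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class Linear:
    """Toy module registered as a PyTree (hypothetical, not Inox's implementation)."""

    def __init__(self, weight, bias):
        self.weight = weight
        self.bias = bias

    def __call__(self, x):
        return x @ self.weight + self.bias

    def tree_flatten(self):
        # The arrays are the children (leaves); there is no auxiliary static data.
        return (self.weight, self.bias), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


layer = Linear(jnp.ones((3, 2)), jnp.zeros(2))


def loss_fn(layer, x):
    return jnp.sum(layer(x) ** 2)


# Gradients are returned as a `Linear` with the same structure as `layer`.
grads = jax.grad(loss_fn)(layer, jnp.ones((4, 3)))
```

`grads` is itself a `Linear` instance holding the gradients of `weight` and `bias`, which is the same behavior the Getting started examples rely on when they call `jax.grad` on a whole model.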

> Inox means "stainless steel" in French 🔪

## Installation

The `inox` package is available on [PyPI](https://pypi.org/project/inox), which means it is installable via `pip`.

```
pip install inox
```

Alternatively, if you need the latest features, you can install it from the repository.

```
pip install git+https://github.com/francois-rozet/inox
```

## Getting started

Modules are defined with an intuitive PyTorch-like syntax,

```python
import jax
import inox.nn as nn

init_key, data_key = jax.random.split(jax.random.key(0))

class MLP(nn.Module):
    def __init__(self, key):
        keys = jax.random.split(key, 3)

        self.l1 = nn.Linear(3, 64, key=keys[0])
        self.l2 = nn.Linear(64, 64, key=keys[1])
        self.l3 = nn.Linear(64, 3, key=keys[2])
        self.relu = nn.ReLU()

    def __call__(self, x):
        x = self.l1(x)
        x = self.l2(self.relu(x))
        x = self.l3(self.relu(x))

        return x

model = MLP(init_key)
```

and are compatible with JAX transformations.

```python
X = jax.random.normal(data_key, (1024, 3))
Y = jax.numpy.sort(X, axis=-1)

@jax.jit
def loss_fn(model, x, y):
    pred = jax.vmap(model)(x)
    return jax.numpy.mean((y - pred) ** 2)

grads = jax.grad(loss_fn)(model, X, Y)
```

However, if a tree contains strings or boolean flags, it becomes incompatible with JAX transformations. For this reason, Inox provides lifted transformations that consider all non-array leaves as static.

```python
import inox

model.name = 'stainless'  # not an array

@inox.jit
def loss_fn(model, x, y):
    pred = jax.vmap(model)(x)
    return jax.numpy.mean((y - pred) ** 2)

grads = inox.grad(loss_fn)(model, X, Y)
```
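The idea behind a lifted transformation can be sketched with plain JAX: flatten the arguments, trace only the array leaves, and hand every other leaf to `jax.jit` as a static (hashable) argument. The names `lifted_jit`, `_call` and `_ARRAY` are hypothetical; this is a sketch under those assumptions, not Inox's implementation of `inox.jit`.

```python
import functools

import jax
import jax.numpy as jnp
import numpy as np

_ARRAY = object()  # sentinel marking the positions of array leaves


def _is_array(x):
    return isinstance(x, (jax.Array, np.ndarray))


@functools.partial(jax.jit, static_argnums=(0, 1, 2))
def _call(fn, treedef, static_leaves, array_leaves):
    # Re-insert the traced arrays at the positions marked by the sentinel.
    arrays = iter(array_leaves)
    leaves = [next(arrays) if leaf is _ARRAY else leaf for leaf in static_leaves]
    return fn(*jax.tree_util.tree_unflatten(treedef, leaves))


def lifted_jit(fn):
    """Hypothetical sketch of a jit that treats all non-array leaves as static."""

    @functools.wraps(fn)
    def wrapper(*args):
        leaves, treedef = jax.tree_util.tree_flatten(args)
        array_leaves = [leaf for leaf in leaves if _is_array(leaf)]
        # Non-array leaves (str, bool, ...) must be hashable to serve as static args.
        static_leaves = tuple(_ARRAY if _is_array(leaf) else leaf for leaf in leaves)
        return _call(fn, treedef, static_leaves, array_leaves)

    return wrapper


@lifted_jit
def loss_fn(model, x):
    # `model["name"]` is a plain Python string here, so this branch runs at trace time.
    scale = 2.0 if model["name"] == "stainless" else 1.0
    return scale * jnp.sum(model["w"] * x)


model = {"w": jnp.ones(3), "name": "stainless"}
loss = loss_fn(model, jnp.arange(3.0))  # 2 * (0 + 1 + 2)
```

Plain `jax.jit` would reject the string leaf, whereas here it simply becomes part of the compilation cache key. The trade-off is that every distinct value of a static leaf (a new `name`, a different Python float) triggers a new compilation, and static leaves must be hashable.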

Inox also provides a partition mechanism to split the static definition of a module (structure, strings, flags, ...) from its dynamic content (parameters, indices, statistics, ...), which is convenient for updating parameters.

```python
model.mask = jax.numpy.array([1, 0, 1])  # not a parameter

static, params, others = model.partition(nn.Parameter)

@jax.jit
def loss_fn(params, others, x, y):
    model = static(params, others)
    pred = jax.vmap(model)(x)
    return jax.numpy.mean((y - pred) ** 2)

grads = jax.grad(loss_fn)(params, others, X, Y)
params = jax.tree_util.tree_map(lambda p, g: p - 0.01 * g, params, grads)

model = static(params, others)
```
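The partition mechanism can be approximated with `jax.tree_util` alone. In this sketch, `partition_tree` and `is_float_array` are hypothetical helpers, and filters are plain predicates rather than types such as `nn.Parameter`; Inox's `Module.partition` may differ in its details.

```python
import jax
import jax.numpy as jnp
import numpy as np


def partition_tree(tree, *filters):
    """Hypothetical sketch: split `tree` into a `combine` function holding the
    static part, one dict of leaves per filter, and a dict of remaining arrays."""
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    groups = [{} for _ in range(len(filters) + 1)]
    static = {}  # non-array leaves, kept out of the dynamic groups

    for i, leaf in enumerate(leaves):
        for group, keep in zip(groups, filters):
            if keep(leaf):
                group[i] = leaf
                break
        else:  # no filter matched
            if isinstance(leaf, (jax.Array, np.ndarray)):
                groups[-1][i] = leaf
            else:
                static[i] = leaf

    def combine(*dynamic):
        merged = dict(static)
        for group in dynamic:
            merged.update(group)
        return jax.tree_util.tree_unflatten(treedef, [merged[i] for i in range(len(leaves))])

    return (combine, *groups)


def is_float_array(x):
    return isinstance(x, jax.Array) and jnp.issubdtype(x.dtype, jnp.floating)


model = {"w": jnp.ones(3), "mask": jnp.array([1, 0, 1]), "name": "stainless"}
combine, params, others = partition_tree(model, is_float_array)


@jax.jit
def loss_fn(params, others, x):
    model = combine(params, others)
    return jnp.sum(model["w"] * model["mask"] * x)


x = jnp.array([1.0, 2.0, 3.0])
grads = jax.grad(loss_fn)(params, others, x)  # differentiates `params` only
params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
model = combine(params, others)
```

Only the float arrays in `params` are differentiated; the integer `mask` travels in `others`, and the string never reaches `jax.jit` because it stays in the closure of `combine`.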

For more information, check out the documentation and tutorials at [inox.readthedocs.io](https://inox.readthedocs.io).

## Contributing

If you have a question or an issue, or would like to contribute, please read our [contributing guidelines](https://github.com/francois-rozet/inox/blob/master/CONTRIBUTING.md).
