Metadata-Version: 2.4
Name: blaze-torch
Version: 0.0.4
Summary: A PyTorch adapter for forward-only model definition
License: MIT
Requires-Python: >=3.10
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: torch>=2.0
Provides-Extra: dev
Requires-Dist: pytest>=8.0; extra == "dev"
Dynamic: description
Dynamic: description-content-type
Dynamic: license
Dynamic: license-file
Dynamic: provides-extra
Dynamic: requires-dist
Dynamic: requires-python
Dynamic: summary

<div align="center">
  <img src="https://raw.githubusercontent.com/baosws/blaze/main/assets/blaze.png" alt="blaze logo" height="200"/>
  <h1>Blaze: Write less. Build more.</h1>
</div>

<div align="center">

[![PyPI](https://img.shields.io/pypi/v/blaze-torch)](https://pypi.org/project/blaze-torch/)

</div>

A PyTorch adapter inspired by [Haiku's](https://dm-haiku.readthedocs.io) functional programming model. Write forward-only models using inline layer calls — no `nn.Module` boilerplate — and get back a proper `nn.Module` with full parameter management and other goodies.

## Table of Contents

- [Table of Contents](#table-of-contents)
- [✨ Why blaze?](#-why-blaze)
  - [Example: Convolutional network](#example-convolutional-network)
  - [🗑️ What gets eliminated](#️-what-gets-eliminated)
- [🚀 Features](#-features)
- [📦 Installation](#-installation)
- [🧑‍💻 Quickstart](#-quickstart)
- [📖 Core concepts](#-core-concepts)
  - [🔁 Two-phase execution](#-two-phase-execution)
  - [🏋️ Training](#️-training)
  - [🧩 User-defined modules (`bl.Module`)](#-user-defined-modules-blmodule)
  - [🎛️ Raw parameters (`bl.get_parameter`)](#️-raw-parameters-blget_parameter)
  - [💾 Non-trainable state (`bl.get_state` / `bl.set_state`)](#-non-trainable-state-blget_state--blset_state)
  - [🔌 Wrapping existing modules (`bl.wrap`)](#-wrapping-existing-modules-blwrap)
  - [🏷️ Custom names](#️-custom-names)
  - [🏗️ Custom initialization (`init_fn`)](#️-custom-initialization-init_fn)
  - [⚡ Compilation](#-compilation)
  - [🧱 Available layers](#-available-layers)
- [🔗 Related projects](#-related-projects)
- [📄 License](#-license)

## ✨ Why blaze?

The traditional way of defining PyTorch models makes you write every layer **twice** (declared in `__init__`, used in `forward`) and requires naming each one (naming things being the archnemesis of programmers), which slows down iterative prototyping considerably. You also have to track and hardcode the input size of every layer by hand, even when it could be computed trivially from the previous layer's output. Blaze removes all of that: layers are written once, inline, exactly where they are used, and input sizes can be read straight from the live tensor via `x.shape` during `init()`.

### Example: Convolutional network

**Traditional PyTorch:**

```python
class ConvNet(nn.Module):
    def __init__(self):          # ← boilerplate you must write
        super().__init__()       # ← boilerplate you must write
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)  # ← named here...
        self.bn1   = nn.BatchNorm2d(32)               # ← named here...
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)  # ← named here...
        self.bn2   = nn.BatchNorm2d(64)               # ← named here...
        self.pool  = nn.AdaptiveAvgPool2d(1)          # ← named here...
        self.fc    = nn.Linear(64, 10)                # ← named here & must know input size!

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))  # ← ...and used here
        x = F.relu(self.bn2(self.conv2(x)))  # ← ...and used here
        x = self.pool(x).flatten(1)          # ← ...and used here
        return self.fc(x)                    # ← ...and used here, what's the output dim again?

model = ConvNet()
```

**Blaze:**

```python
# No class. No __init__. No self. No invented names. Only logic.
def forward(x):
    x = F.relu(bl.BatchNorm2d(32)(bl.Conv2d(3, 32, 3, padding=1)(x)))
    x = F.relu(bl.BatchNorm2d(64)(bl.Conv2d(32, 64, 3, padding=1)(x)))
    x = bl.AdaptiveAvgPool2d(1)(x).flatten(1)
    return bl.Linear(x.shape[-1], 10)(x)  # ← input size computed from the tensor

model = bl.transform(forward)
model.init(torch.randn(1, 3, 32, 32)) # discovers and creates all modules
```

### 🗑️ What gets eliminated

| PyTorch requirement | With blaze |
|---|---|
| `class MyModel(nn.Module)` | Plain function or thin `bl.Module` subclass |
| `def __init__(self)` | Not needed |
| `super().__init__()` | Not needed |
| `self.layer = nn.Linear(...)` | Not needed — layers are created inline |
| Inventing a name for every layer | Auto-derived from class name and deduplicated |
| `nn.ModuleList` / `nn.ModuleDict` for dynamic structure | A plain Python loop or dict |
| Passing hyperparameters through `__init__` to store for `forward` | Just use them directly in the function |
| Manually tracking input sizes across layers | Use `x.shape` — sizes are inferred from the live tensor during `init()` |

---

## 🚀 Features

- 🧹 **No `nn.Module` boilerplate** — define models as plain functions; layers are called inline.
- 🔌 **Drop-in compatible** — `BlazeModule` is a standard `nn.Module`; training loops, optimizers, `state_dict`, and deployment code need no changes.
- ⚙️ **Automatic parameter management** — weights are discovered on the first `init()` pass, reused on every subsequent call, and organized into hierarchical paths (e.g. `"block.linear"`) derived automatically from class names (or overridden with `name=`).
- 📐 **Dynamic size inference** — since `init()` runs with a real tensor, layer sizes can be computed from `x.shape` instead of hardcoded — no more manually tracking dimensions across layers.
- 🧩 **Composable modules** — subclass `bl.Module` to build reusable components; scopes nest correctly no matter how deep.
- 🎛️ **Raw parameters** — `get_parameter()` creates a learnable `nn.Parameter` scoped to the current path, without any surrounding module.
- 💾 **Non-trainable state** — `get_state()` / `set_state()` create and update buffer tensors (analogous to `hk.get_state` / `hk.set_state`) that are tracked by the module but excluded from gradient updates.
- ⚡ **`torch.jit.script/trace` and `torch.compile`** (experimental) — after `init()`, models can be compiled for performance and deployment.
- 🧱 **Built-in layer wrappers** — covers linear, conv, norm, pooling, activation, dropout, recurrent, embedding, attention, transformer, and shape layers.
- 🔌 **Seamless integration** — `bl.wrap` lets you use any existing `nn.Module`, pretrained model, or third-party layer directly inside a blaze function.

## 📦 Installation

```bash
pip install blaze-torch
```

## 🧑‍💻 Quickstart
```python
import torch
import blaze as bl

def forward(x, hidden_dim):
    x = bl.Linear(x.shape[-1], hidden_dim)(x)
    x = bl.ReLU()(x)
    x = bl.Linear(hidden_dim, 1)(x)
    return x

model = bl.transform(forward, hidden_dim=128) # pass kwargs as static hyperparameters
model.init(torch.randn(4, 10))   # discovers and creates all modules

out = model(torch.randn(4, 10))  # normal nn.Module usage
```

## 📖 Core concepts

### 🔁 Two-phase execution

`bl.transform` wraps your forward function. Calling `.init(sample_input)` runs an **INIT** pass that discovers every layer call and registers it into an internal registry keyed by its hierarchical path (e.g. `"block.linear"`). Subsequent calls run in **APPLY** mode, reusing the registered modules by call order.

```python
model = bl.transform(forward)   # empty model, no weights yet
x = torch.randn(4, 10)          # sample input: batch of 4, 10 features
model.init(x)                   # INIT pass: creates weights
output = model(x)               # APPLY pass: reuses weights
```
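
To picture the two phases without any framework, here is a small dependency-free sketch. It is not blaze's implementation: `TinyRegistry`, `layer`, `begin_pass`, and the `Scale` stand-in are hypothetical names invented for illustration, and the only behaviour taken from the description above is that the first pass creates objects under a path and later passes look the same objects up again.

```python
# Illustrative only: a toy registry mimicking an INIT pass followed by APPLY passes.
# None of these names come from blaze.

class TinyRegistry:
    def __init__(self):
        self.modules = {}    # path -> created object
        self.mode = "INIT"
        self._counts = {}    # base name -> number of uses in the current pass

    def begin_pass(self):
        # Reset per-pass counters so the same call order yields the same paths
        self._counts = {}

    def layer(self, base, factory):
        # Deduplicate repeated names within one pass: "scale", "scale_1", ...
        n = self._counts.get(base, 0)
        self._counts[base] = n + 1
        path = base if n == 0 else f"{base}_{n}"
        if self.mode == "INIT":
            self.modules[path] = factory()   # create once
        return self.modules[path]            # reuse on every later pass


class Scale:
    created = 0  # counts constructions, to show APPLY creates nothing new

    def __init__(self, factor):
        Scale.created += 1
        self.factor = factor

    def __call__(self, x):
        return x * self.factor


reg = TinyRegistry()

def forward(x):
    x = reg.layer("scale", lambda: Scale(2))(x)
    x = reg.layer("scale", lambda: Scale(3))(x)
    return x

reg.begin_pass()
init_out = forward(1)     # INIT pass: both Scale objects are created

reg.mode = "APPLY"
reg.begin_pass()
apply_out = forward(1)    # APPLY pass: the factories are never called
```

Both passes return the same value, and `Scale.created` stays at 2 after the APPLY pass, which is the property that lets an optimizer hold on to one stable set of weights.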

### 🏋️ Training

`BlazeModule` is a standard `nn.Module` — use any PyTorch optimizer:

```python
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for x, y in dataloader:
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()
```

### 🧩 User-defined modules (`bl.Module`)

Subclass `bl.Module` and implement `__call__` to group layers into a reusable component under a shared scope. The class name is automatically converted to snake_case for scoping, and repeated instantiations are deduplicated with a numeric suffix.

```python
class MLP(bl.Module):
    def __call__(self, x):
        x = bl.Linear(x.shape[-1], 128)(x)
        x = bl.GELU()(x)
        x = bl.Linear(128, x.shape[-1])(x)
        return x

def forward(x):
    x = MLP()(x)   # registry paths: "mlp.linear", "mlp.gelu", "mlp.linear_1"
    x = MLP()(x)   # registry paths: "mlp_1.linear", ...
    return x
```
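
The naming rule can be reproduced in a few lines of plain Python. This is a sketch under assumptions, not blaze's code: `to_snake_case` and `NameScope` are hypothetical helpers, and the regular expressions are one common snake_case recipe that happens to yield the names shown in the comments above. How blaze treats other class names is not specified here.

```python
import re
from collections import Counter

def to_snake_case(class_name):
    # Insert "_" before a capitalised word, then before any capital
    # that directly follows a lowercase letter or digit.
    s = re.sub(r"(.)([A-Z][a-z]+)", r"\1_\2", class_name)
    s = re.sub(r"([a-z0-9])([A-Z])", r"\1_\2", s)
    return s.lower()

class NameScope:
    """Hands out unique names within one scope: base, base_1, base_2, ..."""

    def __init__(self):
        self._seen = Counter()

    def unique(self, class_name):
        base = to_snake_case(class_name)
        n = self._seen[base]
        self._seen[base] += 1
        return base if n == 0 else f"{base}_{n}"

top = NameScope()

first_mlp = top.unique("MLP")
inner = NameScope()  # each MLP instance gets its own inner scope
inner_names = [inner.unique(c) for c in ("Linear", "GELU", "Linear")]
paths = [f"{first_mlp}.{name}" for name in inner_names]

second_mlp = top.unique("MLP")
```

Here `paths` holds `"mlp.linear"`, `"mlp.gelu"`, and `"mlp.linear_1"`, and `second_mlp` is `"mlp_1"`, matching the example above.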

### 🎛️ Raw parameters (`bl.get_parameter`)

Create a learnable `nn.Parameter` directly, scoped to the current name context. Analogous to `hk.get_parameter`.

```python
def forward(x):
    scale = bl.get_parameter("scale", (x.shape[-1],), init_fn=torch.ones)
    bias  = bl.get_parameter("bias",  (x.shape[-1],), init_fn=torch.zeros)
    return x * scale + bias
```

### 💾 Non-trainable state (`bl.get_state` / `bl.set_state`)

Create and update buffer tensors (non-trainable, tracked by the module). Analogous to `hk.get_state` / `hk.set_state`.

```python
def forward(x):
    running_mean = bl.get_state("running_mean", (x.shape[-1],), init_fn=torch.zeros)
    bl.set_state("running_mean", running_mean * 0.9 + x.mean(0) * 0.1)
    return x - running_mean
```
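
To make the update rule concrete, the following dependency-free sketch replays the same arithmetic on plain Python lists. The module-level `state` dict and the local `get_state` / `set_state` functions are stand-ins written for this example only; they imitate the call pattern above but are not blaze's functions.

```python
# Illustrative stand-ins: a dict plays the role of the module's buffers.
state = {}

def get_state(name, size, init_fn):
    # Create the buffer on first use, then return the stored value
    if name not in state:
        state[name] = init_fn(size)
    return state[name]

def set_state(name, value):
    state[name] = value

def zeros(size):
    return [0.0] * size

def column_mean(batch):
    rows = len(batch)
    return [sum(col) / rows for col in zip(*batch)]

def forward(batch):
    running_mean = get_state("running_mean", len(batch[0]), zeros)
    mean = column_mean(batch)
    set_state("running_mean", [0.9 * r + 0.1 * m for r, m in zip(running_mean, mean)])
    # The output uses the value read before the update, as in the snippet above
    return [[v - r for v, r in zip(row, running_mean)] for row in batch]

batch = [[10.0, 20.0], [30.0, 40.0]]   # column means are 20 and 30

out1 = forward(batch)
after_first = list(state["running_mean"])    # 0.9 * 0 + 0.1 * [20, 30]

out2 = forward(batch)
after_second = list(state["running_mean"])   # 0.9 * [2, 3] + 0.1 * [20, 30]
```

On the first call the buffer is still zero, so the input passes through unchanged while the stored mean moves to roughly `[2, 3]`; the second call subtracts that value and moves the buffer to roughly `[3.8, 5.7]`.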

### 🔌 Wrapping existing modules (`bl.wrap`)

Have an existing `nn.Module`, a pretrained model, or a third-party layer? `bl.wrap` lets you use it directly inside a blaze function — no subclassing or redefining needed:

```python
def forward(x):
    encoder = bl.wrap(lambda: torchvision.models.resnet18(weights="DEFAULT"))
    x = encoder(x)
    x = bl.Linear(x.shape[-1], 10)(x)
    return x
```

The factory is called once during `init()` and the resulting module is reused on every subsequent forward call. Pass `name=` to override the default registry key:

```python
x = bl.wrap(lambda: nn.Linear(10, 64), name="encoder")(x)
```

### 🏷️ Custom names

Pass `name=` to any layer call to override the auto-derived name:

```python
def forward(x):
    x = bl.Linear(10, 64, name="encoder")(x)
    x = bl.Linear(64, 10, name="decoder")(x)
    return x
```

### 🏗️ Custom initialization (`init_fn`)

All layer wrappers and `bl.Module` subclasses accept an `init_fn=` callback that runs once during `init()` and is skipped on subsequent forward calls:

```python
def forward(x):
    x = bl.Linear(10, 64, init_fn=lambda m: nn.init.xavier_uniform_(m.weight))(x)
    x = bl.Conv2d(64, 32, 3, padding=1, init_fn=lambda m: nn.init.kaiming_normal_(m.weight))(x)
    x = bl.BatchNorm2d(32, init_fn=lambda m: nn.init.ones_(m.weight))(x)
    return x
```

Works with `bl.Module` subclasses too:

```python
class Block(bl.Module):
    def __init__(self, dim, init_fn=None):
        super().__init__(init_fn=init_fn)
        self.dim = dim

    def __call__(self, x):
        return bl.Linear(x.shape[-1], self.dim)(x)

def forward(x):
    return Block(dim=64, init_fn=my_custom_init)(x)
```

### ⚡ Compilation

After `.init()`, models can be passed to `torch.jit.trace`, `torch.jit.script`, and `torch.compile` (support is experimental):

```python
model = bl.transform(forward)
model.init(torch.randn(2, 10))

traced  = torch.jit.trace(model, torch.randn(2, 10))
scripted = torch.jit.script(model)
compiled = torch.compile(model)
```

### 🧱 Available layers

All wrappers accept the same arguments as their `torch.nn` counterparts, along with optional arguments `name` and `init_fn`.

| Category | Layers |
|---|---|
| Linear | `Linear`, `Bilinear` |
| Conv | `Conv1d/2d/3d`, `ConvTranspose1d/2d/3d` |
| Norm | `BatchNorm1d/2d/3d`, `SyncBatchNorm`, `InstanceNorm1d/2d/3d`, `LayerNorm`, `GroupNorm`, `RMSNorm` |
| Pooling | `MaxPool1d/2d/3d`, `AvgPool1d/2d/3d`, `AdaptiveAvgPool1d/2d/3d`, `AdaptiveMaxPool1d/2d/3d` |
| Activation | `ReLU`, `ReLU6`, `LeakyReLU`, `PReLU`, `ELU`, `SELU`, `CELU`, `GELU`, `Mish`, `SiLU`, `Tanh`, `Sigmoid`, `Hardsigmoid`, `Hardswish`, `Softmax`, `LogSoftmax`, `Softplus` |
| Dropout | `Dropout`, `Dropout1d/2d/3d`, `AlphaDropout` |
| Recurrent | `LSTM`, `GRU`, `RNN`, `LSTMCell`, `GRUCell`, `RNNCell` |
| Embedding | `Embedding`, `EmbeddingBag` |
| Attention | `MultiheadAttention` |
| Transformer | `Transformer`, `TransformerEncoder`, `TransformerDecoder`, `TransformerEncoderLayer`, `TransformerDecoderLayer` |
| Shape | `Flatten`, `Unflatten`, `Upsample`, `PixelShuffle`, `PixelUnshuffle` |
| Misc | `Identity` |
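
One way to picture how a wrapper can accept the underlying layer's arguments plus `name` and `init_fn` is the sketch below. It is a guess at the general shape, not blaze's implementation: `make_wrapper` and `FakeLinear` are hypothetical, the layer is built eagerly rather than deferred to `init()`, and the lowercase default key is a simplification.

```python
class FakeLinear:
    """Stand-in for a torch.nn layer so the sketch has no dependencies."""

    def __init__(self, in_features, out_features, bias=True):
        self.in_features = in_features
        self.out_features = out_features
        self.bias = bias
        self.weight = [[0.0] * in_features for _ in range(out_features)]

def make_wrapper(layer_cls):
    def build(*args, name=None, init_fn=None, **kwargs):
        # `name` and `init_fn` are consumed here; everything else
        # is forwarded unchanged to the wrapped layer's constructor.
        module = layer_cls(*args, **kwargs)
        if init_fn is not None:
            init_fn(module)   # custom initialisation runs once, at creation
        key = name or layer_cls.__name__.lower()
        return key, module
    return build

Linear = make_wrapper(FakeLinear)

def fill_ones(m):
    m.weight = [[1.0] * m.in_features for _ in range(m.out_features)]

key, layer = Linear(3, 2, bias=False, name="encoder", init_fn=fill_ones)
default_key, plain = Linear(3, 2)
```

Keeping `name` and `init_fn` keyword-only means they cannot collide with positional constructor arguments of the wrapped layer.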

## 🔗 Related projects

| Project | Framework | Description |
|---|---|---|
| [dm-haiku](https://github.com/google-deepmind/dm-haiku) | JAX | The original inspiration. Transforms stateful `hk.Module` code into pure `(init, apply)` function pairs via `hk.transform`. |
| [Flax NNX](https://github.com/google/flax) | JAX | Google's neural network library for JAX. The newer NNX API uses PyTorch-style `__init__`/`__call__` with mutable state; the older Linen API is closer to Haiku's functional style. |
| [Equinox](https://github.com/patrick-kidger/equinox) | JAX | Neural networks as callable PyTrees. Models are plain Python dataclasses; parameters live in the tree rather than a separate registry, making them compatible with `jax.jit`/`jax.grad` directly. |
| [torch.func](https://docs.pytorch.org/docs/stable/func.html) | PyTorch | PyTorch's built-in functional transforms (formerly `functorch`). `torch.func.functional_call` lets you call an existing `nn.Module` with an explicit parameter dict, enabling per-sample gradients, meta-learning, etc. |
| [PyTorch Lightning](https://github.com/Lightning-AI/pytorch-lightning) | PyTorch | Training loop abstraction over `nn.Module`. Reduces boilerplate around the train/val/test cycle but keeps the imperative `nn.Module` programming model. |

## 📄 License

[MIT](LICENSE)
