Metadata-Version: 2.4
Name: gns-pytorch
Version: 0.1.2
Summary: Simple Gradient Noise Scale (GNS) calculation for PyTorch
Author-email: elyx <elio@pascarelli.com>
Requires-Python: >=3.9
Requires-Dist: torch>=2.0.0
Description-Content-Type: text/markdown

# GNS PyTorch

This is the easiest way to calculate GNS (Gradient Noise Scale) for your PyTorch models. No hooks, gradient accumulation, or multi-GPU setup needed. Just pass in your per-example losses and model.

## What's GNS?

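GNS measures how noisy your gradients are relative to the true gradient. The "simple" noise scale from McCandlish et al. is `B_simple = tr(Σ) / |G|²`, where `Σ` is the covariance of per-example gradients and `G` is the full-dataset gradient. It approximates the critical batch size, beyond which larger batches give diminishing returns. See <https://arxiv.org/pdf/1812.06162> and <https://openreview.net/forum?id=xINTMAvPQA>.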
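To make the definition concrete, here is a minimal sketch of the unbiased estimator of `B_simple = tr(Σ) / |G|²` described by McCandlish et al. (2018), computed from a batch of per-example gradients. It compares the mean squared norm of single-example gradients with the squared norm of the batch-mean gradient. It uses plain Python lists for readability (in practice the vectors would be flattened torch tensors), `gns_from_per_example_grads` is a hypothetical name, and this is not necessarily how gns-pytorch computes GNS internally.

```python
def gns_from_per_example_grads(grads):
    """Hypothetical helper: estimate B_simple = tr(Sigma) / |G|^2.

    grads: list of B per-example gradient vectors (lists of floats), B >= 2.
    Uses the small/big batch estimators of McCandlish et al. (2018) with
    B_small = 1 (single examples) and B_big = B (the whole batch).
    """
    b = len(grads)
    if b < 2:
        raise ValueError("need at least two per-example gradients")
    dim = len(grads[0])

    # |G_small|^2: mean squared norm of the individual per-example gradients
    small_sq = sum(sum(v * v for v in g) for g in grads) / b

    # |G_big|^2: squared norm of the batch-mean gradient
    mean_grad = [sum(g[j] for g in grads) / b for j in range(dim)]
    big_sq = sum(v * v for v in mean_grad)

    # Unbiased estimates of the true-gradient norm |G|^2 and of tr(Sigma)
    g_sq = (b * big_sq - small_sq) / (b - 1)
    trace_sigma = (small_sq - big_sq) / (1 - 1 / b)

    # Note: g_sq can be tiny or negative for small, noisy batches
    return trace_sigma / g_sq
```

For the batch `[[1.0], [3.0]]` the sample variance is 2 and the unbiased `|G|²` estimate is 4 − 2/2 = 3, so the function returns 2/3. Because the `|G|²` estimate can be close to zero on small batches, single GNS values are noisy and are best smoothed over many steps.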

## Install

```bash
pip install gns-pytorch
```

## Usage

Simple usage:

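```python
from gns_pytorch import compute_gns
import torch

model = YourModel()  # placeholder: any torch.nn.Module
optimizer = torch.optim.Adam(model.parameters())

gns_ema = None   # running EMA of the GNS estimate
global_step = 0

def training_step(batch):
    global gns_ema, global_step
    x, y = batch
    logits = model(x)
    # reduction='none' keeps one loss per example, which compute_gns needs
    per_example_losses = torch.nn.functional.cross_entropy(logits, y, reduction='none')

    # Estimate GNS only every 100 steps to limit the overhead
    if global_step % 100 == 0:
        gns_value = compute_gns(per_example_losses, model)
        # Seed the EMA with the first estimate instead of an undefined value
        gns_ema = gns_value if gns_ema is None else 0.9 * gns_ema + 0.1 * gns_value
        print(f"Current GNS (EMA): {gns_ema}")

    loss = per_example_losses.mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    global_step += 1
```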

## Adaptive Batch Size Scheduling

With an accurate GNS estimate you can schedule your batch size (via gradient accumulation) so that it stays near the critical batch size throughout training, which can substantially improve convergence speed and sample efficiency. This is similar in spirit to the batch size ramp-up used for DeepSeek-V3.
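A minimal sketch of turning a smoothed GNS estimate into a gradient accumulation count, assuming (following McCandlish et al.) that the critical batch size is roughly equal to the GNS. The helper `accumulation_steps_for` and its arguments are hypothetical and not part of gns-pytorch.

```python
import math


def accumulation_steps_for(gns, micro_batch_size, max_steps):
    """Hypothetical helper: choose how many micro-batches to accumulate so
    that the effective batch size (micro_batch_size * steps) is close to the
    GNS, assumed here to approximate the critical batch size."""
    # GNS estimates can be non-finite or non-positive when |G|^2 is poorly
    # estimated; fall back to no accumulation in that case.
    if not math.isfinite(gns) or gns <= 0:
        return 1
    steps = round(gns / micro_batch_size)
    # Clamp to [1, max_steps] to bound the effective batch size
    return max(1, min(max_steps, steps))
```

For example, with a smoothed GNS of 1000 and micro-batches of 64, this picks 16 accumulation steps (an effective batch of 1024). In a training loop you would recompute it every few hundred steps, call `optimizer.step()` once every `steps` micro-batches, and divide each micro-batch loss by `steps` so the accumulated gradient stays a mean.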

## Tips

- Call `compute_gns` only every N steps (e.g. N ≥ 100) to keep the overhead low
- Smooth the GNS values with an EMA, since individual estimates are very noisy
- The `param_percentage` argument lets you sample a subset of model parameters for faster computation
- Pass `use_vmap=True` to compute per-example gradients in parallel with vmap, which is faster (unfortunately, PyTorch's vmap does not yet compose with flex attention or `torch.compile`)
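A plain EMA initialized at zero underestimates the GNS for its first few updates. One option, sketched below under the assumption that you track the EMA yourself, is Adam-style bias correction: divide the raw average by `1 - beta**t` after `t` updates. `BiasCorrectedEMA` is a hypothetical helper, not part of gns-pytorch.

```python
class BiasCorrectedEMA:
    """Hypothetical helper: exponential moving average with Adam-style bias
    correction, for smoothing noisy GNS estimates."""

    def __init__(self, beta=0.9):
        if not 0.0 <= beta < 1.0:
            raise ValueError("beta must be in [0, 1)")
        self.beta = beta
        self._avg = 0.0   # raw EMA, biased toward its zero initialization
        self._count = 0   # number of updates seen so far

    def update(self, value):
        """Fold in a new GNS estimate and return the corrected average."""
        self._count += 1
        # float() also accepts Python numbers and 0-d torch tensors
        self._avg = self.beta * self._avg + (1.0 - self.beta) * float(value)
        return self.value

    @property
    def value(self):
        """Bias-corrected average, or None before the first update."""
        if self._count == 0:
            return None
        return self._avg / (1.0 - self.beta ** self._count)
```

Usage: create `tracker = BiasCorrectedEMA(beta=0.9)` once, then every N steps call `gns_ema = tracker.update(compute_gns(per_example_losses, model))`. The first corrected value equals the first estimate, rather than a tenth of it.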
