Metadata-Version: 2.4
Name: smartclip
Version: 0.1.0
Summary: Adaptive gradient clipping for PyTorch, TensorFlow, and JAX
Project-URL: Homepage, https://github.com/your-org/smartclip
Project-URL: Issues, https://github.com/your-org/smartclip/issues
Author: SmartClip Maintainers
License: MIT
License-File: LICENSE
Classifier: License :: OSI Approved :: MIT License
Classifier: Programming Language :: Python
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3 :: Only
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Programming Language :: Python :: 3.13
Classifier: Typing :: Typed
Requires-Python: >=3.9
Provides-Extra: bench
Requires-Dist: cycler>=0.12; extra == 'bench'
Requires-Dist: kiwisolver==1.4.7; extra == 'bench'
Requires-Dist: matplotlib>=3.8; extra == 'bench'
Requires-Dist: python-dateutil>=2.8; extra == 'bench'
Provides-Extra: dev
Requires-Dist: build>=1.2; extra == 'dev'
Requires-Dist: mypy>=1.10; extra == 'dev'
Requires-Dist: pre-commit>=3.7; extra == 'dev'
Requires-Dist: ruff>=0.6; extra == 'dev'
Requires-Dist: types-setuptools; extra == 'dev'
Provides-Extra: docs
Requires-Dist: mkdocs-material>=9.5; extra == 'docs'
Requires-Dist: mkdocs>=1.6; extra == 'docs'
Requires-Dist: mkdocstrings[python]>=0.25; extra == 'docs'
Provides-Extra: jax
Requires-Dist: flax>=0.8; extra == 'jax'
Requires-Dist: optax>=0.2; extra == 'jax'
Provides-Extra: test
Requires-Dist: hypothesis>=6; extra == 'test'
Requires-Dist: numpy>=1.26; extra == 'test'
Requires-Dist: pytest-cov>=5; extra == 'test'
Requires-Dist: pytest-randomly>=3; extra == 'test'
Requires-Dist: pytest-timeout>=2; extra == 'test'
Requires-Dist: pytest-xdist>=3; extra == 'test'
Requires-Dist: pytest>=8; extra == 'test'
Provides-Extra: tf
Provides-Extra: torch
Requires-Dist: pytorch-lightning<3,>=2.3; extra == 'torch'
Requires-Dist: transformers<5,>=4.42; extra == 'torch'
Description-Content-Type: text/markdown

# smartclip

[![PyPI version](https://img.shields.io/pypi/v/smartclip.svg)](https://pypi.org/project/smartclip/)
[![CI](https://github.com/stefangordon/smartclip/actions/workflows/ci.yml/badge.svg?branch=main)](https://github.com/stefangordon/smartclip/actions/workflows/ci.yml)
[![Docs Build](https://github.com/stefangordon/smartclip/actions/workflows/docs.yml/badge.svg?branch=main)](https://github.com/stefangordon/smartclip/actions/workflows/docs.yml)
[![Docs](https://img.shields.io/badge/docs-mkdocs%20material-blue)](https://stefangordon.github.io/smartclip)
[![Python Versions](https://img.shields.io/pypi/pyversions/smartclip.svg)](https://pypi.org/project/smartclip/)
[![License: MIT](https://img.shields.io/badge/license-MIT-green.svg)](LICENSE)

Adaptive gradient clipping for PyTorch, TensorFlow, and JAX.

SmartClip helps keep training stable by adapting the clipping threshold at every step, and you can enable it with one line of code.

See the full [documentation](https://stefangordon.github.io/smartclip/) for details of the algorithms, framework usage examples, and logging metrics.

## Supported Algorithms

- AutoClip: Seetharaman et al., 2020 (MLSP). Clips gradients to a chosen percentile of the gradient norms observed so far in training.
  - [AutoClip: Adaptive Gradient Clipping for Source Separation Networks (arXiv:2007.14469)](https://arxiv.org/abs/2007.14469)
- Adaptive Gradient Clipping (AGC, NFNets-style): Brock, De, Smith, and Simonyan, 2021. Clips a gradient when its norm exceeds a fixed fraction of the corresponding parameter's norm (unit-wise in the paper), so the threshold scales with the weights.
  - [High-Performance Large-Scale Image Recognition Without Normalization (arXiv:2102.06171)](https://arxiv.org/abs/2102.06171)
- Z-Score clipping (EMA mean/std): standard z-score thresholding using streaming (exponential moving average) estimates of the mean and variance of gradient norms.
  - `zmax` controls how aggressively gradients are clipped: the threshold is `mean + zmax * std` over recent norms. A higher `zmax` clips less (more tolerant); a lower `zmax` clips more (more aggressive). Start at `zmax=3.0`; try `2.0–2.5` if gradient spikes destabilize training, or `3.5–4.0` if training seems over-clipped.
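The three rules differ mainly in how they pick a threshold. The sketch below expresses each one on plain Python floats so the arithmetic is easy to follow. It is an illustration under stated assumptions, not smartclip's implementation: the class and function names are hypothetical, and the nearest-rank percentile, the AGC `clipping=0.01` and `eps=1e-3` defaults, and the `beta=0.99` EMA decay are our choices. A real clipper would compute norms from framework tensors.

```python
import math
from typing import List, Optional


def _l2(values: List[float]) -> float:
    """Euclidean norm of a flat list of floats."""
    return math.sqrt(sum(v * v for v in values))


class AutoClipSketch:
    """Hypothetical AutoClip-style clipper: threshold = percentile of all norms seen so far."""

    def __init__(self, percentile: float = 10.0) -> None:
        self.percentile = percentile
        self.history: List[float] = []

    def scale(self, grad_norm: float) -> float:
        """Record grad_norm, then return the factor to multiply gradients by."""
        self.history.append(grad_norm)
        ordered = sorted(self.history)
        # Nearest-rank percentile: smallest value with >= p% of samples at or below it.
        rank = max(1, math.ceil(self.percentile * len(ordered) / 100.0))
        threshold = ordered[rank - 1]
        return min(1.0, threshold / grad_norm) if grad_norm > 0 else 1.0


def agc_clip_unit(grad: List[float], weight: List[float],
                  clipping: float = 0.01, eps: float = 1e-3) -> List[float]:
    """Hypothetical AGC for one unit: cap ||grad|| at clipping * max(||weight||, eps)."""
    max_norm = clipping * max(_l2(weight), eps)
    g_norm = _l2(grad)
    if g_norm > max_norm:
        return [g * (max_norm / g_norm) for g in grad]
    return list(grad)


class EmaZScoreSketch:
    """Hypothetical z-score clipper: threshold = mean + zmax * std, EMA statistics."""

    def __init__(self, zmax: float = 3.0, beta: float = 0.99) -> None:
        self.zmax = zmax
        self.beta = beta
        self.mean: Optional[float] = None
        self.var = 0.0

    def scale(self, grad_norm: float) -> float:
        """Return the clip factor from the current stats, then fold grad_norm in."""
        if self.mean is None:
            self.mean = grad_norm  # no statistics yet, so leave the first step unclipped
            return 1.0
        threshold = self.mean + self.zmax * math.sqrt(self.var)
        factor = min(1.0, threshold / grad_norm) if grad_norm > 0 else 1.0
        # Incremental EMA update of mean and variance (alpha = 1 - beta).
        delta = grad_norm - self.mean
        self.mean += (1.0 - self.beta) * delta
        self.var = self.beta * (self.var + (1.0 - self.beta) * delta * delta)
        return factor


auto = AutoClipSketch(percentile=50.0)
print([auto.scale(n) for n in (1.0, 1.0, 4.0)])  # [1.0, 1.0, 0.25]
print([round(v, 6) for v in agc_clip_unit([3.0, 4.0], [100.0, 0.0])])  # [0.6, 0.8]
z = EmaZScoreSketch(zmax=3.0)
print([z.scale(n) for n in (1.0, 1.0, 1.0, 10.0)])  # [1.0, 1.0, 1.0, 0.1]
```

In this sketch the z-score clipper computes the threshold before folding the new norm into its statistics, so a single spike is judged against the history that preceded it; whether smartclip updates its statistics with the raw or the clipped norm is not stated here.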

## Install

```bash
pip install smartclip
```

Optional extras add helpers for specific frameworks. Install the framework itself first (PyTorch, TensorFlow, or JAX), following the vendor's instructions:

```bash
pip install "smartclip[torch]"    # PyTorch + Lightning/Transformers helpers
pip install "smartclip[tf]"       # TensorFlow/Keras helpers
pip install "smartclip[jax]"      # JAX/Flax/Optax helpers
```

## Quickstart

### PyTorch

```python
import torch
import torch.nn as nn
import smartclip as sc

model = MyModel().to("cpu")  # MyModel: your nn.Module
loss_fn = nn.CrossEntropyLoss()  # any loss suited to your task
opt = torch.optim.AdamW(model.parameters(), lr=3e-4)

with sc.clip_context(model, opt):  # defaults to AutoClip
    for x, y in loader:  # loader: your DataLoader
        opt.zero_grad(set_to_none=True)
        loss = loss_fn(model(x), y)
        loss.backward()
        opt.step()  # clipped automatically before the update
```

### TensorFlow/Keras

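```python
import tensorflow as tf
import smartclip as sc

model = MyModel()  # MyModel: your tf.keras.Model
opt = tf.keras.optimizers.Adam(learning_rate=3e-4)
model.compile(optimizer=opt, loss=loss_fn)  # loss_fn: your Keras loss

with sc.clip_context(model, opt, clipper=sc.ZScoreClip(zmax=3.0)):  # use Z-score clipping
    model.fit(ds, epochs=5)  # ds: your tf.data.Dataset
```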

### JAX/Optax

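```python
import jax
import optax
import smartclip as sc

model = MyModel()  # MyModel: your Flax linen Module
params = model.init(jax.random.PRNGKey(0), sample_input)  # sample_input: an example batch
tx = optax.adam(3e-4)
opt_state = tx.init(params)

with sc.clip_context(model, tx):  # wraps tx.update
    for batch in loader:  # loader: your data iterator
        grads = jax.grad(loss_fn)(params, batch)  # loss_fn(params, batch) -> scalar
        updates, opt_state = tx.update(grads, opt_state, params)  # clipped automatically
        params = optax.apply_updates(params, updates)
```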
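Because JAX code is functional, an adaptive clipper wrapped around `tx.update` should not keep running statistics in mutable Python attributes: side effects inside a `jax.jit`-compiled function run only while tracing. Optax transformations instead carry their state through `opt_state`. The sketch below shows that shape for Z-score clipping as a pure `(init, update)` pair. It is hypothetical (`zscore_clip` and `ZClipState` are our names, not smartclip's) and operates on a flat list of floats instead of a JAX pytree so it runs without JAX installed.

```python
import math
from typing import List, NamedTuple, Optional, Tuple


class ZClipState(NamedTuple):
    count: int   # number of updates seen
    mean: float  # EMA of gradient norms
    var: float   # EMA of squared deviations


def zscore_clip(zmax: float = 3.0, beta: float = 0.99):
    """Return (init, update) functions shaped like an Optax GradientTransformation."""

    def init(params: List[float]) -> ZClipState:
        return ZClipState(count=0, mean=0.0, var=0.0)

    def update(grads: List[float], state: ZClipState,
               params: Optional[List[float]] = None) -> Tuple[List[float], ZClipState]:
        norm = math.sqrt(sum(g * g for g in grads))
        if state.count == 0:
            # First step: no statistics yet, so pass gradients through unchanged.
            return list(grads), ZClipState(1, norm, 0.0)
        threshold = state.mean + zmax * math.sqrt(state.var)
        scale = min(1.0, threshold / norm) if norm > 0 else 1.0
        delta = norm - state.mean
        new_state = ZClipState(
            count=state.count + 1,
            mean=state.mean + (1.0 - beta) * delta,
            var=beta * (state.var + (1.0 - beta) * delta * delta),
        )
        # Return new values and a new state; nothing is mutated in place.
        return [g * scale for g in grads], new_state

    return init, update


init, update = zscore_clip(zmax=3.0)
state = init([0.0, 0.0])
for grads in ([3.0, 4.0], [3.0, 4.0], [30.0, 40.0]):
    clipped, state = update(grads, state)
print([round(g, 6) for g in clipped], state.count)  # [3.0, 4.0] 3
```

In real JAX code the norm would come from `optax.global_norm(grads)`, the scaling from `jax.tree_util.tree_map`, and the Python `if` would become `jnp.where` so the function stays traceable under `jax.jit`. The pair could then be wrapped with `optax.GradientTransformation(init, update)` and placed ahead of the optimizer with `optax.chain`.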

See the [documentation](https://stefangordon.github.io/smartclip/) for full guides to TensorFlow, JAX, PyTorch Lightning, Keras, and the Hugging Face Trainer.

## Contributing

We welcome issues and pull requests. See `contribute.md` for developer setup, testing, docs, and release workflows.

## License

MIT
