Metadata-Version: 2.1
Name: jaxhelper
Version: 0.0.4
Summary: Basic tools and helpers for Jax
Home-page: https://github.com/clashluke/jaxhelper
Author: Lucas Nestler
Author-email: github.jaxhelper@nestler.sh
License: BSD
Platform: UNKNOWN
Classifier: Development Status :: 5 - Production/Stable
Classifier: License :: OSI Approved :: BSD License
Classifier: Programming Language :: Python
Classifier: Programming Language :: Python :: 3.7
Classifier: Programming Language :: Python :: 3.8
Classifier: Programming Language :: Python :: 3.9
Classifier: Topic :: Software Development :: Libraries
Classifier: Topic :: Software Development :: Libraries :: Python Modules
Classifier: Intended Audience :: Developers
Requires-Python: >=3.7
Description-Content-Type: text/markdown

# JaxHelper

Basic tools and helpers for JAX.

## Getting Started

### Installation

```bash
python3 -m pip install jaxhelper
```

### Explanation

This repository contains basic helper functions I use every day.\
Here are some highlights:

* **remat**: function decorator that rematerializes hidden states during the backward pass ("activation checkpointing")
* **softmax**:
    * computes `exp` in fp32 and the matmul in bf16, for improved convergence and speed
    * stores fewer intermediates yet computes the gradient faster
* **attention**:
    * recomputes hidden states instead of storing them
    * uses O(K * N) rather than O(N^2) memory, following
      [Self-attention Does Not Need O(n^2) Memory](https://arxiv.org/abs/2112.05682), with no slowdown
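
The signature of jaxhelper's own `remat` decorator isn't documented here, so the sketch below illustrates the same idea with JAX's built-in `jax.checkpoint` (also exposed as `jax.remat`). The function `mlp_block` and all shapes are hypothetical; the point is that the gradients match the plain version while the block's intermediates are recomputed during the backward pass instead of being stored.

```python
import jax
import jax.numpy as jnp


def mlp_block(x, w1, w2):
    # Hypothetical block: the hidden activation h is what we'd rather not keep in memory.
    h = jnp.tanh(x @ w1)
    return h @ w2


# jax.checkpoint drops mlp_block's intermediates after the forward pass
# and recomputes them when the backward pass needs them.
remat_block = jax.checkpoint(mlp_block)


def loss(x, w1, w2):
    return jnp.sum(remat_block(x, w1, w2) ** 2)


key1, key2, key3 = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(key1, (8, 16))
w1 = jax.random.normal(key2, (16, 32))
w2 = jax.random.normal(key3, (32, 4))

# Same gradients as without rematerialization; only peak memory differs.
grads = jax.grad(loss, argnums=(1, 2))(x, w1, w2)
```

The trade-off is one extra forward computation of the block in exchange for not holding `h` between the forward and backward passes.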

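A minimal sketch of the softmax idea, assuming a plain `jax.custom_vjp` formulation; it is not jaxhelper's implementation, and `softmax_fp32` and `attend` are hypothetical names. The exponential and normalization run in float32, the surrounding matmuls stay in bfloat16, and the custom backward pass keeps only the softmax output as a residual, since the softmax gradient can be written entirely in terms of that output.

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def softmax_fp32(x):
    # Hypothetical helper: upcast so exp and the normalizing sum run in float32,
    # then cast the probabilities back to the input dtype (e.g. bfloat16).
    x32 = x.astype(jnp.float32)
    e = jnp.exp(x32 - x32.max(axis=-1, keepdims=True))
    return (e / e.sum(axis=-1, keepdims=True)).astype(x.dtype)


def _softmax_fwd(x):
    y = softmax_fp32(x)
    # Only the output y is saved for the backward pass; the logits and
    # exponentials are not kept as residuals.
    return y, y


def _softmax_bwd(y, g):
    y32, g32 = y.astype(jnp.float32), g.astype(jnp.float32)
    # Softmax VJP expressed purely through the output: y * (g - sum(g * y)).
    dx = y32 * (g32 - jnp.sum(g32 * y32, axis=-1, keepdims=True))
    return (dx.astype(y.dtype),)


softmax_fp32.defvjp(_softmax_fwd, _softmax_bwd)


def attend(q, k, v):
    # Hypothetical single-head attention: matmuls in bfloat16, softmax internals in float32.
    q, k, v = (t.astype(jnp.bfloat16) for t in (q, k, v))
    scores = (q @ jnp.swapaxes(k, -1, -2)) * q.shape[-1] ** -0.5
    return softmax_fp32(scores) @ v
```

With bfloat16 inputs, `attend` returns a bfloat16 result, while the numerically sensitive `exp` and sum never run in bfloat16.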

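A minimal sketch of chunked attention in the spirit of the linked paper, assuming single-head, unmasked attention and a key length divisible by `chunk_size`; `chunked_attention`, `_block`, and `chunk_size` are hypothetical names, not jaxhelper's API. Keys and values are processed chunk by chunk with a running maximum and running sum, so only `(n_q, chunk_size)` score blocks are formed at a time, and `jax.checkpoint` makes the backward pass recompute each block instead of storing it. The paper also chunks the queries; this sketch only chunks keys and values.

```python
import jax
import jax.numpy as jnp


@jax.checkpoint  # recompute this block's scores in the backward pass instead of storing them
def _block(q, k_c, v_c, row_max):
    s = q @ k_c.T  # (n_q, chunk_size) scores for one key/value chunk
    new_max = jnp.maximum(row_max, s.max(axis=-1))
    p = jnp.exp(s - new_max[:, None])
    return p @ v_c, p.sum(axis=-1), new_max


def chunked_attention(q, k, v, chunk_size=128):
    """Hypothetical single-head, unmasked attention over key/value chunks.

    q: (n_q, d); k: (n_k, d); v: (n_k, d_v). n_k must be divisible by chunk_size.
    """
    n_k, d = k.shape
    assert n_k % chunk_size == 0, "sketch assumes n_k is a multiple of chunk_size"
    q = q * d ** -0.5
    k_chunks = k.reshape(n_k // chunk_size, chunk_size, d)
    v_chunks = v.reshape(n_k // chunk_size, chunk_size, v.shape[-1])

    def step(carry, kv):
        acc, row_sum, row_max = carry
        block_out, block_sum, new_max = _block(q, kv[0], kv[1], row_max)
        # Rescale everything accumulated so far to the new running maximum.
        correction = jnp.exp(row_max - new_max)
        acc = acc * correction[:, None] + block_out
        row_sum = row_sum * correction + block_sum
        return (acc, row_sum, new_max), None

    n_q = q.shape[0]
    init = (
        jnp.zeros((n_q, v.shape[-1]), q.dtype),
        jnp.zeros((n_q,), q.dtype),
        jnp.full((n_q,), -jnp.inf, q.dtype),
    )
    (acc, row_sum, _), _ = jax.lax.scan(step, init, (k_chunks, v_chunks))
    # Final normalization turns the weighted value sums into softmax(q k^T / sqrt(d)) v.
    return acc / row_sum[:, None]
```

The running-maximum rescaling keeps every `exp` argument at or below zero, which is what lets the chunks be combined without ever forming the full `(n_q, n_k)` score matrix.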
