Metadata-Version: 2.1
Name: reinmax
Version: 0.1.0
Summary: ReinMax Algorithm
Home-page: https://github.com/microsoft/reinmax
Author: Lucas Liu
Author-email: llychinalz@gmail.com
License: MIT
Platform: UNKNOWN
Classifier: Development Status :: 2 - Pre-Alpha
Classifier: Intended Audience :: Developers
Classifier: Natural Language :: English
Classifier: Programming Language :: Python :: 3.7
Classifier: Programming Language :: Python :: 3.8
Classifier: Programming Language :: Python :: 3.9
Description-Content-Type: text/markdown

![PyTorch](https://img.shields.io/badge/PyTorch-%23EE4C2C.svg?style=flat&logo=PyTorch&logoColor=white)
![PyPI - Python Version](https://img.shields.io/pypi/pyversions/reinmax) 
![GitHub](https://img.shields.io/github/license/microsoft/reinmax) 
![PyPI](https://img.shields.io/pypi/v/reinmax) 

<h2 align="center">ReinMax</h2>
<h4 align="center"> Beyond Straight-Through</h4>

<p align="center">
  <a href="#straight-through-and-how-it-works">Straight-Through</a> •
  <a href="#better-performance-with-negligible-computation-overheads">ReinMax</a> •
  <a href="#how-to-use">How To Use</a> •
  <a href="#examples">Examples</a> •
  <a href="#citation">Citation</a> •
  <a href="https://github.com/microsoft/reinmax/tree/main/LICENSE">License</a>
</p>

[ReinMax]() achieves **second-order** accuracy and is **as fast as** the original Straight-Through, which has first-order accuracy.

<!-- <h4 align="center"><i>Straight-Through and How It Works</i></h4> -->
## Straight-Through and How It Works

Straight-Through, as shown below, bridges discrete variables (`y_hard`) and back-propagation.
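```python
import torch
import torch.nn.functional as F

def one_hot_multinomial(probs):
    # draw one category per row and return it as a one-hot vector (non-differentiable)
    index = torch.multinomial(probs, num_samples=1).squeeze(-1)
    return F.one_hot(index, num_classes=probs.size(-1)).to(probs.dtype)

theta = torch.randn(4, 10, requires_grad=True)  # logits
y_soft = theta.softmax(dim=-1)

# one_hot_multinomial is a non-differentiable function
y_hard = one_hot_multinomial(y_soft)

# with straight-through, the derivative of y_hard will
# act as if you had `y_soft` in the forward
y_hard = y_soft - y_soft.detach() + y_hard
```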
How Straight-Through works has long been a mystery, leaving open questions such as whether we should use:
- `y_soft - y_soft.detach()`,
- `(theta/tau).softmax() - (theta/tau).softmax().detach()`,
- or something else?
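
The choice is not cosmetic: both options give exactly the same forward value, yet they back-propagate different gradients. The sketch below is our own illustration, not taken from the ReinMax paper. It fixes one one-hot sample, applies straight-through with each of the two surrogates listed above, and compares the results under an arbitrary linear loss; the temperature `tau = 2.0`, the loss weights, and the helper `st_grad` are assumptions chosen for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
theta = torch.randn(1, 5, requires_grad=True)  # logits
tau = 2.0  # hypothetical temperature, for illustration only

# draw one fixed one-hot sample so both variants share the same y_hard
with torch.no_grad():
    idx = torch.multinomial(theta.softmax(dim=-1), 1).squeeze(-1)
    sample = F.one_hot(idx, theta.size(-1)).to(theta.dtype)

def st_grad(surrogate):
    # hypothetical helper: straight-through with the given surrogate;
    # the forward value is the sample, the backward pass goes through `surrogate`
    y_hard = surrogate - surrogate.detach() + sample
    weights = torch.arange(theta.size(-1), dtype=theta.dtype)
    loss = (y_hard * weights).sum()  # arbitrary downstream loss
    (grad,) = torch.autograd.grad(loss, theta)
    return y_hard.detach(), grad

fwd1, g1 = st_grad(theta.softmax(dim=-1))
fwd2, g2 = st_grad((theta / tau).softmax(dim=-1))
print(torch.equal(fwd1, fwd2))  # True: identical forward values
print(torch.allclose(g1, g2))   # False: different gradients
```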


<!-- 
<h4 align="center"><i>Better Performance with Negligible Computation Overheads</i></h4> -->
## Better Performance with Negligible Computation Overheads

[We reveal]() that Straight-Through works as a special case of the forward Euler method, a numerical method with first-order accuracy.
Inspired by Heun's method, a numerical method that achieves second-order accuracy without requiring the Hessian or other second-order derivatives, we propose ReinMax, which *approximates the gradient with second-order accuracy at negligible computational overhead.*
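
To make the first-order versus second-order distinction concrete, here is a generic numerical-analysis sketch (not the ReinMax estimator itself) comparing forward Euler and Heun's method on the test equation dy/dt = -y, whose exact solution is y(t) = e^(-t). The ODE, the end time, and the step counts are illustrative assumptions.

```python
import math

def f(t, y):
    # right-hand side of the test ODE dy/dt = -y
    return -y

def euler(y0, h, steps):
    y, t = y0, 0.0
    for _ in range(steps):
        y = y + h * f(t, y)  # follow the slope at the start of the step
        t += h
    return y

def heun(y0, h, steps):
    y, t = y0, 0.0
    for _ in range(steps):
        k1 = f(t, y)               # slope at the start of the step
        k2 = f(t + h, y + h * k1)  # slope at the Euler-predicted end point
        y = y + h * (k1 + k2) / 2  # average both slopes; no second derivatives needed
        t += h
    return y

exact = math.exp(-1.0)  # y(1) when y(0) = 1

def errors(n):
    # absolute error at t = 1 for Euler and Heun with n equal steps
    h = 1.0 / n
    return abs(euler(1.0, h, n) - exact), abs(heun(1.0, h, n) - exact)

for n in (10, 20, 40):
    euler_err, heun_err = errors(n)
    print(f"n={n:3d}  euler={euler_err:.2e}  heun={heun_err:.2e}")
```

Doubling the number of steps roughly halves the Euler error but cuts the Heun error by about a factor of four, which is what first-order and second-order accuracy mean in practice. Heun pays for this with only one extra slope evaluation per step.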

## How to use?

### install 
```
pip install reinmax
```

### enjoy

```diff
from reinmax import reinmax
...

def forward(self, ...):
...
- y_soft = theta.softmax()
- y_hard = one_hot_multinomial(y_soft) 
- y_hard = y_soft - y_soft.detach() + y_hard 
+ y_hard, y_soft = reinmax(theta)
...
```

## Examples

- [Polynomial Programming]()
- [MNIST-VAE]()
- [ListOps]()

## Citation
Please cite the following paper if you find our work useful. Thanks!

>Liyuan Liu, Xiaodong Liu, Jianfeng Gao, Weizhu Chen, and Jiawei Han (2020). Understanding the Difficulty of Training Transformers. Proc. 2020 Conf. on Empirical Methods in Natural Language Processing (EMNLP'20).
```
@inproceedings{liu2020admin,
  title={Understanding the Difficulty of Training Transformers},
  author = {Liu, Liyuan and Liu, Xiaodong and Gao, Jianfeng and Chen, Weizhu and Han, Jiawei},
  booktitle = {Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing (EMNLP 2020)},
  year={2020}
}
```

