Metadata-Version: 2.4
Name: trim-transformer
Version: 0.1.3
Summary: A linear-attention transformer implementation with KV caching.
Home-page: https://github.com/emanuel-nuclearsoftware/trim-transformer
Author: Emanuel Gordis
Author-email: emanuel@nuclearsoftware.com
Classifier: Programming Language :: Python :: 3
Classifier: License :: OSI Approved :: MIT License
Classifier: Operating System :: OS Independent
Requires-Python: >=3.6
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: torch
Dynamic: author
Dynamic: author-email
Dynamic: classifier
Dynamic: description
Dynamic: description-content-type
Dynamic: home-page
Dynamic: license-file
Dynamic: requires-dist
Dynamic: requires-python
Dynamic: summary

# Trim Transformer

`trim-transformer` is a lightweight PyPI package that replicates the familiar interface of `torch.nn.TransformerEncoder`, but with an attention function of the form Attn(Q,K,V) = QK^TV, which we call multi-linear attention. This implementation has time complexity O(nd^2), where n is the sequence length and d is the model dimension. Since the time complexity is linear in the sequence length, this implementation is well suited for high sequence length tasks. Attention in this form has shown success in operator learning tasks, see [Choose a Transformer: Fourier or Galerkin](https://arxiv.org/abs/2105.14995).

This implementation is particularly relevent for training physics models where high sequence length can come from large grid sizes, long time periods, or both.

Additionally, this implementation supports key-value caching for inference that is also linear in the number of tokens generated. Finally, this implementation supports custom weight initialization functions for the query, key, and value projection matrices, and custom normalization layers for the query, key,
and value activations.

![PyPI](https://img.shields.io/pypi/v/trim-transformer?color=%2334D058&logo=pypi) ![License](https://img.shields.io/github/license/emanuel-nuclearsoftware/trim-transformer)

---

## Installation

The package is published on PyPI and only depends on PyTorch:

```bash
pip install trim-transformer
```

Alternatively, install the latest commit from GitHub:

```bash
pip install git+https://github.com/eg-trim/trim-transformer.git
```

---
## Benchmarks

Below are some benchmark plots demonstrating model performance and resource usage on the Navier-Stokes dataset from https://arxiv.org/abs/2010.08895:

The Trim Transformer achives more than 90% reduction in memory usage compared to a standard Pytorch transformer using softmax attention and 3.5x faster time per epoch while maintaining very similar validation loss. As grid size and sequence length increase these gains become even more drastic.
![Memory Usage](plots/mem_use.png)

![Time per Epoch](plots/time:epoch.png)
![Training Loss](plots/loss.png)

---
## Quickstart

```python
import torch
from trim_transformer.transformer_layers import TrimTransformerEncoderLayer, TrimTransformerEncoder

layer = TrimTransformerEncoderLayer(d_model=EMBED_DIM, nhead=NUM_HEADS, batch_first=True)
model = TrimTransformerEncoder(layer, num_layers=NUM_LAYERS)

x = torch.randn(8, 2048, 512)  # (batch, seq_len, dim)

# Standard forward pass (causal mask optional)
out = model(x, is_causal=True)  # (batch, seq_len, dim)
```

See [tutorial_notebook.ipynb](tutorial_notebook.ipynb) for demonstrations of each feature. And see [trim_vs_softmax.ipynb](trim_vs_softmax.ipynb) for an example of a complete pipeline and a comparison to a PyTorch baseline.

## Masking

A significant departure from PyTorch syntax is the structure of the mask. Multi-linear attention with arbitrary boolean masks cannot be computed in time linear in the sequence length. Instead, this package supports masks such that the i-th query attends to all keys up to index m_i. Such masks can be specified by an integer array of length n, with values in [0, n-1], where n is the sequence length.

For example, a causal mask of length `n` is given by `torch.arange(n)`. To illustrate further, consider the one-dimensional mask `[2, 0, 1]`. This corresponds to the following two-dimensional mask, where, following the PyTorch convention, `True` indicates that an element is masked.

|             | Key 0 | Key 1 | Key 2 |
|-------------|:-----:|:-----:|:-----:|
| **Query 0** | `False` | `False` | `False` |
| **Query 1** | `False` | `True`  | `True`  |
| **Query 2** | `False` | `False` | `True`  |

## Key-value caching

Inference with key-value caching can be performed with a simple loop.

```python
def generate(model, initial_tokens, num_new_tokens=5):
    """Autoregressive generation with caching"""
    model.eval()  # optional
    generated_sequence = []
    new_token = initial_tokens.clone()

    for _ in range(num_new_tokens):
        # Must set use_kv_cache=True and update_kv_cache=True to use key-value caching.
        # use_kv_cache=True means that the model will key-value cache that is already stored.
        # update_kv_cache=True means that the model will update the key-value cache with the
        # keys and values from the new token.
        output = model(new_token, is_causal=True, use_kv_cache=True, update_kv_cache=True)
        generated_sequence.append(output)

    # Clear the key-value cache after generation. If you don't clear the cache, then if
    # use_kv_cache=True in the future, the model will continue to use the key-value cache
    # on future forward passes.
    model.clear_kv_cache()
    return generated_sequence
```

---

## License

`trim-transformer` is released under the MIT License.  See [LICENSE](LICENSE) for the full text.
