Metadata-Version: 2.1
Name: srf-attention
Version: 1.0.16
Summary: Simplex random feature attention in PyTorch for both training and inference
Home-page: https://github.com/alexjlevenston/srf-attention
Author: Alex Levenston
Author-email: alexlevenston2021@gmail.com
License: MIT
Classifier: License :: OSI Approved :: MIT License
Classifier: Programming Language :: Python :: 3
Classifier: Topic :: Software Development :: Libraries :: Python Modules
Requires-Python: >=3.6
Description-Content-Type: text/markdown
License-File: LICENSE

# srf-attention
Simplex Random Feature attention, in PyTorch

## A Prelude
### Why? What? Huh?
Softmax attention ate the world. But now it's eating our wallets. Luckily enough for us wordcels, those nifty shape rotators realized that even though softmax isn't stationary, it's amenable to Monte Carlo methods. Translation: we can retrofit pretrained LLMs for recurrent inference! Smarter men than I proceeded to publish [this](https://arxiv.org/abs/2009.14794), [this](https://arxiv.org/abs/2205.15317), and [that](https://arxiv.org/abs/2301.13856). This repo is a PyTorch implementation of "that", with some syntactic sugar added to aid digestion. It's intended to be used for [ERPTRI](https://github.com/alexjlevenston/erptri-train), but do with it what you will.
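
For the curious, here's a minimal sketch of the Monte Carlo trick. It is not this repo's code: it uses plain i.i.d. Gaussian positive random features (the first paper's flavor) rather than the simplex-coupled ones implemented here, and `positive_random_features` is a made-up helper name. The point is only that $`\exp(q^\top k)`$ can be written as an expectation over random projections, so a dot product of feature maps estimates it:

```python
import torch

def positive_random_features(x, w):
    # Hypothetical helper, not part of srf_attention.
    # x: (..., d) inputs, w: (m, d) Gaussian projections -> (..., m) features.
    m = w.shape[0]
    proj = x @ w.T
    half_sq_norm = (x ** 2).sum(dim=-1, keepdim=True) / 2
    return torch.exp(proj - half_sq_norm) / m ** 0.5

torch.manual_seed(0)
d, m = 16, 100_000
q = 0.1 * torch.randn(d, dtype=torch.float64)
k = 0.1 * torch.randn(d, dtype=torch.float64)
w = torch.randn(m, d, dtype=torch.float64)  # i.i.d. Gaussian, NOT simplex-coupled

exact = torch.exp(q @ k)  # the (unnormalized) softmax kernel
approx = positive_random_features(q, w) @ positive_random_features(k, w)
print(exact.item(), approx.item())
```

The estimate tightens as `m` grows. As I understand it, the simplex variant correlates the rows of `w` to cut the variance for the same feature budget.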

### What is this good for?
Well, it really ain't for you open-sourcerers. You're bottlenecked by weight I/O. But for those running large-batch inference, e.g. as part of a synthetic data pipeline, KV cache I/O dominates the cost for sequences > ~700 tokens. [ERPTRI](https://github.com/alexjlevenston/erptri-train) efficiently [sic] drops the KV cache size of any pretrained auto-regressive Transformer from $`O(LD)`$ to $`O(D^2)`$. This repo implements the PyTorch modules necessary for the fine-tuning phase of ERPTRI, and for efficient inference.
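
Where does the $`O(D^2)`$ come from? A rough sketch, assuming generic random-feature linear attention with as many features as head dimensions (this is an illustration, not the repo's inference path, and `phi`, `S`, and `z` are names I made up): once the kernel factorizes into feature maps, causal attention can be run as a recurrence over a fixed-size state instead of a cache that grows with the sequence.

```python
import torch

def phi(x, w):
    # Hypothetical positive feature map; stands in for whatever the model uses.
    return torch.exp(x @ w.T - (x ** 2).sum(dim=-1, keepdim=True) / 2)

torch.manual_seed(0)
L, D = 32, 8
q, k, v = (0.3 * torch.randn(L, D, dtype=torch.float64) for _ in range(3))
w = torch.randn(D, D, dtype=torch.float64)  # n_features == D
fq, fk = phi(q, w), phi(k, w)

# Parallel (training-style) form: causal mask over an L x L score matrix.
scores = torch.tril(fq @ fk.T)
parallel = (scores @ v) / scores.sum(dim=-1, keepdim=True)

# Recurrent (inference-style) form: the "cache" is a D x D matrix plus a
# D-vector, no matter how many tokens have been seen.
S = torch.zeros(D, D, dtype=torch.float64)  # running sum of phi(k_t) v_t^T
z = torch.zeros(D, dtype=torch.float64)     # running sum of phi(k_t)
outputs = []
for t in range(L):
    S += torch.outer(fk[t], v[t])
    z += fk[t]
    outputs.append((fq[t] @ S) / (fq[t] @ z))
recurrent = torch.stack(outputs)

print(torch.allclose(parallel, recurrent))
```

Both forms compute the same outputs. The recurrent one just never keeps past keys and values around, which is the whole appeal when cache I/O is the bottleneck.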

### Next steps
Venture forth and conquer. But first, fine-tune under an ordinary NLL loss on the original pretraining distribution, after performing the [appropriate](#Usage) model surgery. [Here's](https://huggingface.co/datasets/reversebutlerianjihad/AnorexicPajama) the RedPajama subset that was used for the Llama 2 retrofit.

## Installation
Insta-wheel:
```bash
pip install git+https://github.com/alexjlevenston/srf-attention
```

## Usage
```python
import torch
from srf_attention import Attention

device = 'cpu'

B, H, L, D = (1, 8, 1024, 128)

q, k, v = [torch.randn(B, H, L, D, device=device).requires_grad_() for _ in range(3)]

# CHUNK_SIZE controls the memory/compute tradeoff of the attention computation
CHUNK_SIZE=1024

# Simplex Random Feature (SRF) Attention module
# All intermediate computations done in FP32, but cached values are FP16.
# Recomputes the attention matrix in the backward pass instead of storing it:
attn = Attention(d=D, n_features=D, causal=True, device=device)

# During fine-tuning, replace your softmax attention function with this:
o = attn(q, k, v, mode='train', attn_fn='torch', chunk_size=CHUNK_SIZE)

# Use 1 instance for each layer,
# and disable auto-redraw prior to beginning training:
attn.redraw_on_call_(False)

# On each training step, call redraw_() to resample the random features:
attn.redraw_()

# That's it! Now just fine-tune.
```

## Example
Here's an example, using the HF Transformers [diff](https://github.com/alexjlevenston/transformers-llama-srf) I wrote to retrofit Llama with SRF attention:
```python
# Make sure TILE_SIZE env var is set, I use TILE_SIZE=256
import torch
# install using `pip install git+https://github.com/alexjlevenston/transformers-llama-srf`
import transformers
from transformers import LlamaForCausalLM, LlamaTokenizer

tokenizer = LlamaTokenizer.from_pretrained('meta-llama/Llama-2-7b-hf')
model = LlamaForCausalLM.from_pretrained('meta-llama/Llama-2-7b-hf')

for module in model.modules():
  if isinstance(module, transformers.models.llama.modeling_llama.LlamaAttention):
    module.use_fast_attn_(True)
    module.attn_fn.redraw_on_call_(False)

def resample_rfs(model):
  for module in model.modules():
    if isinstance(module, transformers.models.llama.modeling_llama.LlamaAttention):
      module.attn_fn.redraw_(next(model.parameters()).device)

optimizer = YourOptimizerHere()

for step, batch in enumerate(imaginary_dataset):
  inputs, targets = batch
  # Always resample random features manually,
  # because auto-resampling causes issues with checkpointing
  resample_rfs(model)
  outputs = model(inputs)
  logits = outputs.logits.reshape(-1, outputs.logits.shape[-1])
  loss = torch.nn.functional.cross_entropy(logits, targets['input_ids'].reshape(-1))
  # Clear gradients left over from the previous step before backprop
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()
```
