Metadata-Version: 2.1
Name: e2e_sae
Version: 2.0.0
Summary: Repo for training sparse autoencoders end-to-end
Project-URL: repository, https://github.com/ApolloResearch/e2e_sae
Requires-Python: >=3.11
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: torch ~=2.2.0
Requires-Dist: torchvision ~=0.17.0
Requires-Dist: einops ~=0.7.0
Requires-Dist: pydantic ~=2.0
Requires-Dist: wandb ~=0.16.2
Requires-Dist: fire ~=0.5.0
Requires-Dist: tqdm ~=4.66.1
Requires-Dist: pytest ~=8.1.2
Requires-Dist: ipykernel ~=6.29.0
Requires-Dist: transformer-lens ~=1.14.0
Requires-Dist: jaxtyping ~=0.2.25
Requires-Dist: python-dotenv ~=1.0.1
Requires-Dist: zstandard ~=0.22.0
Requires-Dist: matplotlib ~=3.5.3
Requires-Dist: seaborn ~=0.13.2
Requires-Dist: umap-learn ~=0.5.6
Requires-Dist: tenacity ~=8.2.3
Requires-Dist: statsmodels ~=0.14.2
Requires-Dist: automated-interpretability ~=0.0.3
Provides-Extra: dev
Requires-Dist: ruff ~=0.1.14 ; extra == 'dev'
Requires-Dist: pyright ==1.1.362 ; extra == 'dev'
Requires-Dist: pre-commit ~=3.6.0 ; extra == 'dev'

# e2e_sae

This library is used to train and evaluate Sparse Autoencoders (SAEs). It handles the following
training types:
- e2e (end-to-end): Loss function includes sparsity and the KL divergence between the output
    distributions of the original model and the model with SAEs inserted.
- e2e + downstream reconstruction: Loss function includes sparsity, the same output KL divergence,
    and MSE at downstream layers.
- local (i.e. vanilla SAEs): Loss function includes sparsity and MSE at the SAE layer.
- Any combination of the above.
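
The training types differ only in which loss terms are switched on. The framework-free sketch below is not the library's implementation. It computes each term on plain Python lists and combines the terms with per-term coefficients. It assumes an L1 sparsity penalty and KL(P_orig ‖ P_sae) for the output term. The term names, the preset coefficient values, and the direction of the KL are illustrative assumptions, not the library's configuration keys.

```python
import math

# Hypothetical sketch: names and coefficients are illustrative, not e2e_sae's API.
# Real training works on batched tensors; plain lists keep the arithmetic visible.


def softmax(logits: list[float]) -> list[float]:
    m = max(logits)  # subtract the max logit for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(orig_logits: list[float], sae_logits: list[float]) -> float:
    """KL(P_orig || P_sae) between output distributions (direction assumed)."""
    p, q = softmax(orig_logits), softmax(sae_logits)
    return sum(pi * (math.log(pi) - math.log(qi)) for pi, qi in zip(p, q))


def mse(a: list[float], b: list[float]) -> float:
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def l1_sparsity(sae_acts: list[float]) -> float:
    return sum(abs(c) for c in sae_acts)


def total_loss(terms: dict[str, float], coeffs: dict[str, float]) -> float:
    """Weighted sum over the loss terms that have a coefficient."""
    return sum(coeff * terms[name] for name, coeff in coeffs.items())


# Illustrative coefficient presets for each training type (values hypothetical).
LOCAL = {"sparsity": 0.1, "local_mse": 1.0}
E2E = {"sparsity": 0.1, "kl": 1.0}
E2E_DOWNSTREAM = {"sparsity": 0.1, "kl": 1.0, "downstream_mse": 0.5}
```

Leaving a term out of the coefficient dict removes it from the loss. Switching between training types, or mixing them, therefore only changes the coefficients.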

See our [paper](https://publications.apolloresearch.ai/end_to_end_sparse_dictionary_learning), which argues for training SAEs e2e rather than locally. All SAEs presented in the paper can be found at https://wandb.ai/sparsify/gpt2 and can be loaded using this library.

## Usage
### Installation
```bash
pip install e2e_sae
```

### Train SAEs on any [TransformerLens](https://github.com/neelnanda-io/TransformerLens) model
If you would like to track your run with Weights and Biases, place your API key and entity name in
a new file called `.env`. An example is provided in [.env.example](.env.example).
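
A `.env` file is a plain list of `KEY=VALUE` lines. The hypothetical helper below parses such text and is included only to illustrate that format. In practice a loader such as python-dotenv (listed among this package's dependencies) handles this. The variable names `WANDB_API_KEY` and `WANDB_ENTITY` are assumptions based on the standard wandb environment variables, so check [.env.example](.env.example) for the exact keys.

```python
# Hypothetical helper for illustration only; python-dotenv does this in practice.


def parse_env_text(text: str) -> dict[str, str]:
    """Parse simple KEY=VALUE lines, skipping blank lines and # comments."""
    values: dict[str, str] = {}
    for raw_line in text.splitlines():
        line = raw_line.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, _, value = line.partition("=")
        # Drop surrounding whitespace and any surrounding quote characters.
        values[key.strip()] = value.strip().strip('"').strip("'")
    return values


# Assumed variable names; check .env.example for the exact keys.
EXAMPLE_ENV = """\
# Weights and Biases credentials
WANDB_API_KEY=your_api_key
WANDB_ENTITY=your_entity_name
"""
```

For a real file, `parse_env_text(Path(".env").read_text())` (with `from pathlib import Path`) returns the same kind of dict.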

Create a config file (see the GPT-2 configs [here](e2e_sae/scripts/train_tlens_saes/) for examples).
Then run
```bash
python e2e_sae/scripts/train_tlens_saes/run_train_tlens_saes.py <path_to_config>
```

If using a Colab notebook, see [this example](demos/train_saes.ipynb).

Sample wandb sweep configs are provided in [e2e_sae/scripts/train_tlens_saes/](e2e_sae/scripts/train_tlens_saes/).
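
Under W&B's `grid` method, a sweep launches one run per combination of the listed parameter values, so the run count grows multiplicatively. The sketch below expands a flat `parameters` block the same way, which lets you preview a sweep's runs before launching it. It is a hypothetical helper that handles only `values` and `value` entries (not distributions or nested parameters). The parameter names in the example are illustrative and not taken from the repository's sweep files.

```python
import itertools
from typing import Any


def grid_runs(sweep_config: dict[str, Any]) -> list[dict[str, Any]]:
    """Expand a grid sweep's flat `parameters` block into one dict per run."""
    params = sweep_config["parameters"]
    names = list(params)
    # A fixed `value` behaves like a single-element `values` list.
    choices = [spec["values"] if "values" in spec else [spec["value"]] for spec in params.values()]
    return [dict(zip(names, combo)) for combo in itertools.product(*choices)]


# Illustrative sweep; parameter names and values are hypothetical.
example_sweep = {
    "method": "grid",
    "parameters": {
        "seed": {"values": [0, 1]},
        "lr": {"values": [1e-4, 5e-4, 1e-3]},
        "batch_size": {"value": 8},
    },
}
```

Here `grid_runs(example_sweep)` yields 2 × 3 = 6 runs.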

The library also contains scripts for training MLPs and SAEs on MLPs, as well as for training
custom TransformerLens models and SAEs on these models (see [here](e2e_sae/scripts/)).

### Load a Pre-trained SAE
You can load any pre-trained SAE (and accompanying TransformerLens model) trained using this library
from Weights and Biases or locally by running
```python
from e2e_sae import SAETransformer
model = SAETransformer.from_wandb("<entity/project/run_id>")
# or, if stored locally
model = SAETransformer.from_local_path("/path/to/checkpoint/dir") 
```
All runs in our
[paper](https://publications.apolloresearch.ai/end_to_end_sparse_dictionary_learning)
can be loaded this way (e.g. [sparsify/gpt2/tvj2owza](https://wandb.ai/sparsify/gpt2/runs/tvj2owza)).


This will instantiate an `SAETransformer`, which contains a TransformerLens model with SAEs
attached. To do a forward pass without SAEs, use the `forward_raw` method; to do a forward pass with
SAEs, use the `forward` method (or simply call the `SAETransformer` instance).
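
Conceptually, `forward_raw` runs the original model unchanged. `forward` instead swaps the activation at each SAE position for that SAE's reconstruction before later layers see it. The pure-Python toy below illustrates only that control flow. `ToySAE` and `ToySAETransformer` are hypothetical classes that do not mirror the library's signatures or return values, since the real methods operate on tokens and a TransformerLens model.

```python
from typing import Callable

Vector = list[float]
Layer = Callable[[Vector], Vector]


class ToySAE:
    """Hypothetical SAE: ReLU(W_enc x + b_enc) followed by a linear decoder."""

    def __init__(self, w_enc: list[Vector], b_enc: Vector, w_dec: list[Vector]):
        self.w_enc = w_enc  # one row per dictionary element
        self.b_enc = b_enc
        self.w_dec = w_dec  # one row per input dimension

    def __call__(self, x: Vector) -> Vector:
        acts = [
            max(0.0, sum(w * xi for w, xi in zip(row, x)) + b)
            for row, b in zip(self.w_enc, self.b_enc)
        ]
        return [sum(w * a for w, a in zip(row, acts)) for row in self.w_dec]


class ToySAETransformer:
    """Hypothetical stand-in: a stack of layers with SAEs at chosen positions."""

    def __init__(self, layers: list[Layer], saes: dict[int, ToySAE]):
        self.layers = layers
        self.saes = saes  # layer index -> SAE applied to that layer's output

    def forward_raw(self, x: Vector) -> Vector:
        # Original model only: the SAEs are ignored.
        for layer in self.layers:
            x = layer(x)
        return x

    def forward(self, x: Vector) -> Vector:
        # Later layers receive the SAE reconstruction instead of the original activation.
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i in self.saes:
                x = self.saes[i](x)
        return x

    __call__ = forward
```

With an identity encoder and decoder, this toy SAE only zeroes negative components. The two passes therefore differ exactly when the activation at the SAE position has a negative entry.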

The dictionary elements of an SAE can be accessed via `SAE.dict_elements`, which returns the
decoder elements normalized to have norm 1.
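
Below is a minimal sketch of that normalization. It assumes each dictionary element is stored as a column of a decoder matrix of shape (input_dim, n_dict). The function name and the small `eps` guard against zero-norm columns (similar in spirit to `torch.nn.functional.normalize`) are assumptions, not the library's code.

```python
import math


def normalize_dict_elements(decoder: list[list[float]], eps: float = 1e-12) -> list[list[float]]:
    """Return a copy of `decoder` with every column scaled to unit L2 norm.

    `decoder` has shape (input_dim, n_dict): column c is dictionary element c.
    """
    n_rows, n_cols = len(decoder), len(decoder[0])
    # L2 norm of each column, floored at eps to avoid dividing by zero.
    norms = [
        max(math.sqrt(sum(decoder[r][c] ** 2 for r in range(n_rows))), eps)
        for c in range(n_cols)
    ]
    return [[decoder[r][c] / norms[c] for c in range(n_cols)] for r in range(n_rows)]
```

For example, the columns (3, 4) and (0, 2) of `[[3.0, 0.0], [4.0, 2.0]]` become (0.6, 0.8) and (0.0, 1.0).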

### Analysis
To reproduce all of the analysis in our
[paper](https://publications.apolloresearch.ai/end_to_end_sparse_dictionary_learning) use the
scripts in `e2e_sae/scripts/analysis/`.

## Contributing
Developer dependencies are installed with `make install-dev`, which will also install pre-commit
hooks.

Suggested extensions and settings for VSCode are provided in `.vscode/`. To use the suggested
settings, copy `.vscode/settings-example.json` to `.vscode/settings.json`.

There are various `make` commands that may be helpful:

```bash
make check  # Run pre-commit checks on all files (i.e. pyright, ruff linter, and ruff formatter)
make type  # Run pyright on all files
make format  # Run ruff linter and formatter on all files
make test  # Run tests that aren't marked `slow`
make test-all  # Run all tests
```

This library is maintained by [Dan Braun](https://danbraunai.github.io/).

Join the [Open Source Mechanistic Interpretability Slack](https://join.slack.com/t/opensourcemechanistic/shared_invite/zt-2hk7rcm8g-IIuaxpte_1GHp5joc~1kww)
to chat about this library and other projects in the space!
