Metadata-Version: 2.4
Name: reverse_distillation
Version: 0.1.0
Summary: Reverse Distillation for Protein Language Models
Requires-Python: >=3.10
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: accelerate==1.7.0
Requires-Dist: bio==1.8.0
Requires-Dist: biopython==1.85
Requires-Dist: biotite==1.2.0
Requires-Dist: biotraj==1.2.2
Requires-Dist: Brotli==1.1.0
Requires-Dist: datasets
Requires-Dist: decorator==5.2.1
Requires-Dist: dotenv==0.9.9
Requires-Dist: einops==0.8.1
Requires-Dist: fair-esm==2.0.0
Requires-Dist: fbpca==1.0
Requires-Dist: filelock==3.18.0
Requires-Dist: fsspec==2025.3.2
Requires-Dist: gprofiler-official==1.0.0
Requires-Dist: h5py==3.13.0
Requires-Dist: huggingface-hub==0.32.3
Requires-Dist: hydra-core==1.3.2
Requires-Dist: hydra-optuna-sweeper==1.2.0
Requires-Dist: idna==3.10
Requires-Dist: ipykernel>=7.0.1
Requires-Dist: ipython==8.36.0
Requires-Dist: loguru>=0.7.3
Requires-Dist: memory-profiler==0.61.0
Requires-Dist: mygene==3.2.2
Requires-Dist: networkx==3.4.2
Requires-Dist: numpy==2.2.5
Requires-Dist: nvidia-cublas-cu12==12.6.4.1; sys_platform == "linux"
Requires-Dist: nvidia-cuda-cupti-cu12==12.6.80; sys_platform == "linux"
Requires-Dist: nvidia-cuda-nvrtc-cu12==12.6.77; sys_platform == "linux"
Requires-Dist: nvidia-cuda-runtime-cu12==12.6.77; sys_platform == "linux"
Requires-Dist: nvidia-cudnn-cu12==9.5.1.17; sys_platform == "linux"
Requires-Dist: nvidia-cufft-cu12==11.3.0.4; sys_platform == "linux"
Requires-Dist: nvidia-cufile-cu12==1.11.1.6; sys_platform == "linux"
Requires-Dist: nvidia-curand-cu12==10.3.7.77; sys_platform == "linux"
Requires-Dist: nvidia-cusolver-cu12==11.7.1.2; sys_platform == "linux"
Requires-Dist: nvidia-cusparse-cu12==12.5.4.2; sys_platform == "linux"
Requires-Dist: nvidia-cusparselt-cu12==0.6.3; sys_platform == "linux"
Requires-Dist: nvidia-nccl-cu12==2.26.2; sys_platform == "linux"
Requires-Dist: nvidia-nvjitlink-cu12==12.6.85; sys_platform == "linux"
Requires-Dist: nvidia-nvtx-cu12==12.6.77; sys_platform == "linux"
Requires-Dist: omegaconf==2.3.0
Requires-Dist: optuna==2.10.1
Requires-Dist: pandas==2.2.2
Requires-Dist: peft==0.15.2
Requires-Dist: python-dateutil
Requires-Dist: python-dotenv==1.1.0
Requires-Dist: PyYAML==6.0.2
Requires-Dist: ruff==0.14.0
Requires-Dist: safetensors==0.5.3
Requires-Dist: scikit-learn==1.6.1
Requires-Dist: scipy==1.15.3
Requires-Dist: tokenizers==0.21.1
Requires-Dist: torch==2.7.0
Requires-Dist: torchtext==0.18.0
Requires-Dist: torchvision==0.22.0
Requires-Dist: tqdm==4.67.1
Requires-Dist: transformers==4.48.1
Requires-Dist: typing_extensions==4.13.2
Requires-Dist: urllib3==2.4.0
Provides-Extra: dev
Requires-Dist: ruff; extra == "dev"
Requires-Dist: pytest; extra == "dev"
Requires-Dist: pre-commit; extra == "dev"
Dynamic: license-file

# PLM Reverse Distillation

![Reverse Distillation Abstract](rd_abstract.png)

Protein language models (PLMs) scale poorly: for many tasks, mid-sized models often outperform the largest in the same family. **Reverse Distillation** addresses this by decomposing large PLM representations into orthogonal subspaces guided by smaller models of the same family. The resulting embeddings have a Matryoshka-style nested structure — the first *k* dimensions of a larger model's embedding exactly match the smaller model's representation — ensuring larger reverse-distilled models consistently outperform smaller ones.

On ProteinGym benchmarks, reverse-distilled ESM-2 variants outperform their respective baselines at the same embedding dimensionality, with the reverse-distilled 15B model achieving the strongest performance.
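
The nested structure described above can be illustrated with a small NumPy sketch. This is not the repository's implementation: it assumes plain least squares in place of the PCR regressor, uncentered PCA via SVD, and random stand-in matrices instead of real ESM-2 embeddings. The function name `reverse_distill` is hypothetical.

```python
import numpy as np


def reverse_distill(h_small, h_large):
    """Build embeddings whose first d_small columns equal h_small.

    h_small: (n, d_small) embeddings from the smaller model.
    h_large: (n, d_large) embeddings from the larger model, d_large > d_small.
    """
    d_small = h_small.shape[1]
    d_large = h_large.shape[1]
    # Linear map from the small-model space to the large-model space.
    w, *_ = np.linalg.lstsq(h_small, h_large, rcond=None)
    # Residual: the part of h_large not linearly predictable from h_small.
    residual = h_large - h_small @ w
    # Uncentered PCA on the residual; keep enough directions to fill
    # the remaining d_large - d_small dimensions.
    _, _, vt = np.linalg.svd(residual, full_matrices=False)
    extra = residual @ vt[: d_large - d_small].T
    return np.concatenate([h_small, extra], axis=1)


# Random stand-ins for real embeddings (assumption, for illustration only).
rng = np.random.default_rng(0)
h_small = rng.normal(size=(200, 4))
h_large = rng.normal(size=(200, 10))
rd = reverse_distill(h_small, h_large)
```

In this sketch the leading columns of `rd` are a verbatim copy of `h_small`, and the added columns are orthogonal to them across the sample, because least-squares residuals are orthogonal to the regressors.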

## Installation

Requires Python ≥ 3.10 and [`uv`](https://github.com/astral-sh/uv).

```bash
git clone https://github.com/rohitsinghlab/plm_reverse_distillation.git
cd plm_reverse_distillation
uv lock && uv sync
uv pip install -e '.[dev]'
source .venv/bin/activate
```

## Quick Start

See [`inference_tutorial.ipynb`](inference_tutorial.ipynb) for a step-by-step walkthrough of loading pretrained models and extracting embeddings.

Pretrained scalers for all ESM-2 model pairs (8M → 35M → 150M → 650M → 3B → 15B) are available on Hugging Face and are loaded automatically via the model registry:

**[singhlab/plm_reverse_distillation](https://huggingface.co/singhlab/plm_reverse_distillation)**

## Available Models

All models use principal component regression (PCR) as the regressor and PCA for dimensionality reduction. Each model applies the full chain of scalers from ESM-2 8M up to the target size.

| Model name | Chain | Output dim |
| ---------- | ----- | ----------- |
| `esm2.rd/35M` | 8M → 35M | 480 |
| `esm2.rd/150M` | 8M → 35M → 150M | 640 |
| `esm2.rd/650M` | 8M → 35M → 150M → 650M | 1280 |
| `esm2.rd/3B` | 8M → 35M → 150M → 650M → 3B | 2560 |
| `esm2.rd/15B` | 8M → 35M → 150M → 650M → 3B → 15B | 5120 |
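
If the nested property holds for the pretrained scalers, a larger reverse-distilled embedding can be cut down to a smaller model's width by keeping its leading dimensions. The sketch below assumes embeddings are NumPy arrays with the feature axis last. `RD_DIMS` and `truncate_to` are hypothetical names, and the 320-dimensional entry for ESM-2 8M is that checkpoint's standard hidden size, not a value from the table.

```python
import numpy as np

# Output dims from the table; the 8M entry (320) is the ESM-2 8M hidden size
# and is an addition here, since the table starts at 35M.
RD_DIMS = {
    "8M": 320,
    "35M": 480,
    "150M": 640,
    "650M": 1280,
    "3B": 2560,
    "15B": 5120,
}


def truncate_to(embedding, size):
    """Keep the leading dimensions that correspond to a smaller model."""
    dim = RD_DIMS[size]
    if embedding.shape[-1] < dim:
        raise ValueError(f"embedding has fewer than {dim} dimensions")
    return embedding[..., :dim]


# Placeholder array standing in for three mean-pooled esm2.rd/15B embeddings.
emb_15b = np.zeros((3, RD_DIMS["15B"]), dtype=np.float32)
emb_650m = truncate_to(emb_15b, "650M")

# Number of dimensions each stage of the chain adds on top of the previous one.
sizes = list(RD_DIMS)
added = [RD_DIMS[b] - RD_DIMS[a] for a, b in zip(sizes, sizes[1:])]
```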

## Scripts

### Embedding extraction

Extract embeddings from a FASTA file using a pretrained RD model:

```bash
python scripts/extract.py \
    --fasta_file proteins.fasta \
    --output_dir embeddings/ \
    --repr_type mean \
    --batch_size 32
```

Key arguments: `--repr_type` (`per_tok` / `mean` / `bos`), `--repr_layers`, `--batch_size`, `--truncation_seq_length`.
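
The three `--repr_type` values share their names with the options of the `fair-esm` extraction script. The sketch below shows what they mean there, assuming the token layout `[BOS, residues..., EOS]`. Whether `scripts/extract.py` pools in exactly this way is an assumption, and `pool` is a hypothetical helper.

```python
import numpy as np


def pool(token_reprs, seq_len, repr_type):
    """Reduce per-token representations for one sequence.

    token_reprs: (seq_len + 2, dim) array laid out as [BOS, residues..., EOS].
    """
    residues = token_reprs[1 : seq_len + 1]  # drop BOS and EOS
    if repr_type == "per_tok":
        return residues  # one vector per residue
    if repr_type == "mean":
        return residues.mean(axis=0)  # average over residues
    if repr_type == "bos":
        return token_reprs[0]  # vector at the beginning-of-sequence token
    raise ValueError(f"unknown repr_type: {repr_type!r}")


# Toy input: BOS + 4 residues + EOS, 2 features each.
tokens = np.arange(12, dtype=float).reshape(6, 2)
per_tok = pool(tokens, 4, "per_tok")
mean = pool(tokens, 4, "mean")
bos = pool(tokens, 4, "bos")
```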

### Training scalers

Train new scalers on your own data:

```bash
python scripts/train.py \
    --dataset_path proteins.fasta \
    --scalar_path scalers/ \
    --regressor_type pcr \
    --scaler_type rd \
    --n_pretrained_seqs 5000
```

Key arguments: `--regressor_type` (`linear` / `ridge` / `pcr`), `--scaler_type` (`rd` / `naive`), `--pca_type` (`incremental` / `fbpca`), `--n_pretrained_seqs`.
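
For readers unfamiliar with the `pcr` option, principal component regression projects the inputs onto their leading principal components and then fits ordinary least squares in that reduced space. The following is a generic NumPy sketch of the technique, not the code in `scripts/train.py`. The names `fit_pcr` and `predict` are hypothetical.

```python
import numpy as np


def fit_pcr(x, y, n_components):
    """Principal component regression: PCA on x, then least squares on the scores."""
    x_mean, y_mean = x.mean(axis=0), y.mean(axis=0)
    x_centered = x - x_mean
    # Principal directions of x are the right singular vectors of centered x.
    _, _, vt = np.linalg.svd(x_centered, full_matrices=False)
    components = vt[:n_components].T  # (d_in, n_components)
    scores = x_centered @ components
    beta, *_ = np.linalg.lstsq(scores, y - y_mean, rcond=None)
    # Map the coefficients back to the original input space.
    coef = components @ beta  # (d_in, d_out)
    intercept = y_mean - x_mean @ coef
    return coef, intercept


def predict(x, coef, intercept):
    return x @ coef + intercept


# Synthetic, exactly linear data (for illustration only).
rng = np.random.default_rng(0)
x = rng.normal(size=(100, 5))
y = x @ rng.normal(size=(5, 3)) + rng.normal(size=3)

coef_full, intercept_full = fit_pcr(x, y, n_components=5)
coef_low, _ = fit_pcr(x, y, n_components=2)
```

With all components kept, PCR coincides with ordinary least squares. With fewer components, the coefficient matrix has rank at most `n_components`, which acts as a form of regularization.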

## Citation

If you use reverse distillation, please cite:

```bibtex
@inproceedings{catrina2026reverse,
  title   = {Reverse Distillation: Consistently Scaling Protein Language Model Representations},
  author  = {Catrina, Darius and Bepler, Christian and Sledzieski, Samuel and Singh, Rohit},
  booktitle = {International Conference on Learning Representations},
  year    = {2026}
}
```

## License

This project is licensed under the MIT License — see [LICENSE](LICENSE) for details.
