Metadata-Version: 2.4
Name: flash-pref
Version: 0.1.1
Summary: Accelerate LLM preference finetuning with a single line of code
Author-email: Jiahao Li <liplus17@163.com>
Maintainer-email: Jiahao Li <liplus17@163.com>
License: MIT License
Project-URL: Homepage, https://github.com/li-plus/flash-preference
Project-URL: Repository, https://github.com/li-plus/flash-preference.git
Keywords: Prefix Sharing,Direct Preference Optimization,Reward Model,Preference Tuning
Classifier: Development Status :: 3 - Alpha
Classifier: Intended Audience :: Science/Research
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Classifier: License :: OSI Approved :: MIT License
Classifier: Programming Language :: Python :: 3.8
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Requires-Python: >=3.7
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: transformers<4.52.0,>=4.49.0
Dynamic: license-file

# Flash Preference

[![PyPI](https://img.shields.io/pypi/v/flash-pref)](https://pypi.org/project/flash-pref/)
[![License: MIT](https://img.shields.io/badge/license-MIT-blue)](LICENSE)

Accelerate LLM preference tuning via prefix sharing with a single line of code. Applicable to Direct Preference Optimization (DPO), Reward Modeling (RM), Group Relative Policy Optimization (GRPO), etc.

![](docs/prefix_sharing.png)

## Getting Started

Install the stable version from PyPI:
```
pip install flash-pref
```

Or install the latest version from GitHub:
```sh
pip install git+https://github.com/li-plus/flash-preference.git@main
```

All you have to do is to add a `shared_prefix` context wrapping the model forward and backward passes. The common prefixes of the input sequences will be automatically detected and shared, reducing computation and memory footprint without loss of accuracy.
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from flash_pref import shared_prefix

model_id = "Qwen/Qwen2.5-7B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="right")
model = AutoModelForCausalLM.from_pretrained(
    model_id, attn_implementation="flash_attention_2", use_cache=False, torch_dtype=torch.bfloat16, device_map="cuda"
)

prompt = "What are the next 10 numbers of this sequence: " + ", ".join(str(x) for x in range(500))
chosen_response = ", ".join(str(x) for x in range(500, 500 + 10))
rejected_response = ", ".join(str(x) for x in range(500, 500 + 10, 2))

conversations = [
    [{"role": "user", "content": prompt}, {"role": "assistant", "content": chosen_response}],
    [{"role": "user", "content": prompt}, {"role": "assistant", "content": rejected_response}],
]
inputs = tokenizer.apply_chat_template(
    conversations, tokenize=True, padding=True, return_tensors="pt", return_dict=True
).to("cuda")

# ===== MAGIC HERE =====
with shared_prefix(model, input_ids=inputs.input_ids, attention_mask=inputs.attention_mask):
    output = model(**inputs)
    output.logits.backward(torch.randn_like(output.logits))
```

For [huggingface/trl](https://github.com/huggingface/trl) users, a drop-in replacement for trl trainer is also available. Check out the end-to-end training examples below.

| Algorithm                      | Original Trainer                | Accelerated Trainer with Prefix Sharing     | Example                                                            |
| ------------------------------ | ------------------------------- | ------------------------------------------- | ------------------------------------------------------------------ |
| Direct Preference Optimization | `from trl import DPOTrainer`    | `from flash_pref import FlashDPOTrainer`    | [examples/dpo_trl.py](examples/dpo_trl.py)                         |
| Reward Modeling                | `from trl import RewardTrainer` | `from flash_pref import FlashRewardTrainer` | [examples/reward_modeling_trl.py](examples/reward_modeling_trl.py) |

## Benchmark

The performance speedup and memory saved relative to the baseline:

![](docs/perf.png)

Benchmark settings are as below. Please refer to [tests/benchmark.py](tests/benchmark.py) for more details.
* Model: [Qwen/Qwen2.5-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-7B-Instruct) with gradient checkpointing, [Liger-Kernel](https://github.com/linkedin/Liger-Kernel) and [FlashAttention-2](https://github.com/Dao-AILab/flash-attention) enabled.
* Data: mocked pairwise preference data where prompt and response lengths vary from 64 to 16k.
* Computation: 1 forward pass and then 1 backward pass.
* Hardware: 1x NVIDIA A800-SXM4-80GB GPU.

## Developing

**Unit Tests**

Currently tested for LLaMA, Gemma, Gemma2, Qwen2, Qwen2VL and Qwen2.5VL architectures. At least 2 GPUs are required for unit tests. To run the unit tests, type:
```sh
make test
```

**Code Format**

To format the code, type:
```sh
make lint
```

## License

This project is under [MIT License](LICENSE).
