Metadata-Version: 2.1
Name: mlorax
Version: 0.9.0
Summary: MLoRAx is a minimalist library for low-rank adaptation, designed to make parameter-efficient training of Transformer-based models effortless.
Home-page: https://github.com/yongchanghao/MLoRAx
Author: Yongchang Hao
License: Apache 2.0
Description-Content-Type: text/markdown
License-File: LICENSE.md
Requires-Dist: jax
Requires-Dist: flax
Requires-Dist: chex


## Installation

### Choice 1: Just copy the code
You can just copy the code from `mlorax.py` and paste it into your project. This is the easiest way to use `mlorax` if you do not care about future updates.

### Choice 2: Install with pip
You can also install `mlorax` with pip. This is the recommended way to use `mlorax` if you want to receive future updates.
```bash
pip install mlorax
```

### Choice 3: Install from source
You can also install `mlorax` from source. You only need to do this if you want to contribute to the project.
```bash
git clone https://github.com/yongchanghao/MLoRAx.git
cd MLoRAx
pip install -e .
```

## Usage
Converting a Flax model into a LoRA model with `mlorax` takes only a few lines of code. The following snippet converts a T5 model based on HuggingFace's [FlaxT5ForConditionalGeneration](https://huggingface.co/docs/transformers/model_doc/t5#transformers.FlaxT5ForConditionalGeneration) class.

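```diff
+ import mlorax
import optax
from flax.training.train_state import TrainState
from transformers import FlaxT5ForConditionalGeneration

model = FlaxT5ForConditionalGeneration.from_pretrained('t5-small')
- params = model.params
- apply_fn = model.__call__
+ lora_spec = mlorax.LoRASpec(rank=16, rules=['Attention.q', 'Attention.v'])
+ params, apply_fn, merge_fn = mlorax.lora_init(lora_spec, model)
tx = optax.adamw(learning_rate=1e-4)  # any optax optimizer works here
state = TrainState.create(apply_fn=apply_fn, params=params, tx=tx)
```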

That's it! You can now train the model as usual.
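
For intuition about what the conversion changes, the sketch below shows the standard LoRA reparameterization of a single dense layer in plain JAX: the pretrained kernel `w` stays frozen, and only two small factors `a` and `b` of rank `r` are trained. The names `lora_dense` and `alpha` are hypothetical and chosen for illustration; this is a sketch of the technique, not the internals of `mlorax`.

```python
import jax
import jax.numpy as jnp


def lora_dense(x, w, a, b, alpha=16.0):
    """Dense layer with a low-rank update: x @ (W + (alpha / r) * A @ B).

    w: frozen pretrained kernel, shape (d_in, d_out)
    a: trainable down-projection, shape (d_in, r)
    b: trainable up-projection, shape (r, d_out)
    """
    r = a.shape[1]
    scale = alpha / r
    # Evaluate the low-rank path as (x @ A) @ B so A @ B is never materialized.
    return x @ w + scale * (x @ a) @ b


d_in, d_out, r = 512, 512, 16
k_w, k_a, k_x = jax.random.split(jax.random.PRNGKey(0), 3)
w = jax.random.normal(k_w, (d_in, d_out))
# Common LoRA initialization: A is small and random, B is zero, so training
# starts from exactly the pretrained function.
a = jax.random.normal(k_a, (d_in, r)) * 0.01
b = jnp.zeros((r, d_out))
x = jax.random.normal(k_x, (4, d_in))

y = lora_dense(x, w, a, b)

# Trainable parameters: full kernel versus the two low-rank factors.
num_full = d_in * d_out
num_lora = a.size + b.size
```

Because `b` starts at zero, `y` equals `x @ w` at initialization, and with `r = 16` the adapter trains 16,384 parameters instead of the 262,144 in the full kernel.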

### Principles
Whenever possible, use the **returned** `apply_fn` for forward passes. If you need to call a function that does not go through it, pass the merged parameters instead with `params=merge_fn(params)`. For example, to use `model.generate` for text generation:
```diff
- outputs = model.generate(**batch, params=params)
+ outputs = model.generate(**batch, params=merge_fn(params))
```
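
Merging works because the low-rank update is linear: folding `(alpha / r) * A @ B` into the frozen kernel yields a single matrix with the original shape, so any function that expects the pretrained parameter layout can consume it. The sketch below checks this for one kernel in plain JAX; `merge_kernel`, `lora_forward`, and `ALPHA` are hypothetical names for illustration, and `mlorax.merge_fn` is not necessarily implemented this way.

```python
import jax
import jax.numpy as jnp

ALPHA = 16.0  # hypothetical LoRA scaling hyperparameter


def lora_forward(x, w, a, b):
    """Unmerged LoRA forward pass: x @ W + (alpha / r) * (x @ A) @ B."""
    r = a.shape[1]
    return x @ w + (ALPHA / r) * (x @ a) @ b


def merge_kernel(w, a, b):
    """Fold the LoRA update into the base kernel: W' = W + (alpha / r) * A @ B."""
    r = a.shape[1]
    return w + (ALPHA / r) * (a @ b)


d_in, d_out, r = 64, 32, 4
k_w, k_a, k_b, k_x = jax.random.split(jax.random.PRNGKey(0), 4)
w = jax.random.normal(k_w, (d_in, d_out))
a = jax.random.normal(k_a, (d_in, r))
# Use a nonzero B (as after some training) so the comparison is meaningful.
b = jax.random.normal(k_b, (r, d_out))
x = jax.random.normal(k_x, (8, d_in))

merged = merge_kernel(w, a, b)
lora_out = lora_forward(x, w, a, b)
merged_out = x @ merged  # a plain dense layer, no LoRA logic needed
```

The merged kernel has the same shape as `w`, and `merged_out` matches `lora_out` up to floating-point rounding. Since training updates `a` and `b`, merging is usually done on demand for inference, which is why the snippet above calls `merge_fn(params)` at generation time rather than storing merged parameters.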



## Example and Results

For the code used in the example, check out the [examples](examples) folder.




## Citation
If you find MLoRAx useful, please cite it as follows:

```bibtex
@software{hao2023mlorax,
  author = {Yongchang Hao},
  title = {{ML}o{RA}x: a minimalist library for low-rank adaptation for {T}ransformer-based models},
  year = {2023},
  url = {https://github.com/yongchanghao/MLoRAx},
  version = {0.9.0}
}
```
