Metadata-Version: 2.1
Name: drlx
Version: 0.0.2
Summary: DRLX is a library for distributed training of diffusion models via RL
Author: CarperAI
License: MIT
Classifier: Development Status :: 3 - Alpha
Classifier: Environment :: Console
Classifier: Intended Audience :: Science/Research
Classifier: Intended Audience :: Developers
Classifier: License :: OSI Approved :: MIT License
Classifier: Natural Language :: English
Classifier: Programming Language :: Python
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3 :: Only
Classifier: Programming Language :: Python :: Implementation :: CPython
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Classifier: Typing :: Typed
Classifier: Operating System :: Unix
Requires-Python: >=3.9
Description-Content-Type: text/markdown
Requires-Dist: torch
Requires-Dist: torchvision
Requires-Dist: torchtyping
Requires-Dist: einops
Requires-Dist: diffusers
Requires-Dist: transformers
Requires-Dist: accelerate
Requires-Dist: xformers
Requires-Dist: wandb
Requires-Dist: fastprogress
Requires-Dist: matplotlib
Provides-Extra: benchmarks
Requires-Dist: pygraphviz ; extra == 'benchmarks'
Requires-Dist: graphviz ; extra == 'benchmarks'
Requires-Dist: openai ; extra == 'benchmarks'
Provides-Extra: dev
Requires-Dist: black ; extra == 'dev'
Requires-Dist: isort ; extra == 'dev'
Requires-Dist: flake8 ; extra == 'dev'
Requires-Dist: flake8-pyproject ; extra == 'dev'
Requires-Dist: pydocstyle ; extra == 'dev'
Requires-Dist: mypy ; extra == 'dev'
Requires-Dist: pre-commit ; extra == 'dev'
Requires-Dist: pytest ; extra == 'dev'
Requires-Dist: pytest-cov ; extra == 'dev'
Provides-Extra: docs
Requires-Dist: sphinx ==5.3.0 ; extra == 'docs'
Requires-Dist: sphinx-rtd-theme ; extra == 'docs'
Requires-Dist: sphinx-autodoc-typehints ; extra == 'docs'
Provides-Extra: notebook
Requires-Dist: ipython ; extra == 'notebook'
Provides-Extra: sodaracer
Requires-Dist: swig >=4.1.0 ; extra == 'sodaracer'
Requires-Dist: box2d-py ==2.3.8 ; extra == 'sodaracer'
Requires-Dist: pygame ; extra == 'sodaracer'
Provides-Extra: triton
Requires-Dist: tritonclient[all] ; extra == 'triton'

# Diffuser Reinforcement Learning X

DRLX is a library for distributed training of diffusion models via reinforcement learning (RL). It wraps 🤗 Hugging Face's [Diffusers](https://huggingface.co/docs/diffusers/) library and uses [Accelerate](https://huggingface.co/docs/accelerate/) for multi-GPU and multi-node (as yet untested) training.

📖 **[Documentation](https://DRLX.readthedocs.io)**

# Setup

You can install the library from PyPI:

```sh
pip install drlx
```

or from source:

```sh
pip install git+https://github.com/CarperAI/DRLX.git
```

# How to use

Currently, we have only tested the library with Stable Diffusion 1.4, but its plug-and-play design means that, realistically, any denoiser from any pipeline should be usable. Models saved with DRLX are compatible with the pipeline they originated from and can be loaded like any other pretrained model. The only training algorithm currently supported is [DDPO](https://arxiv.org/abs/2305.13301).

```python
# Import a reward model, a prompt pipeline, the trainer, and the config class
from drlx.reward_modelling.aesthetics import Aesthetics
from drlx.pipeline.pickapic_prompts import PickAPicPrompts
from drlx.trainer.ddpo_trainer import DDPOTrainer
from drlx.configs import DRLXConfig

# Build the prompt pipeline and load the training configuration from YAML
pipe = PickAPicPrompts()
config = DRLXConfig.load_yaml("configs/my_cfg.yml")
trainer = DDPOTrainer(config)

# Train on the prompts, scoring generated images with the reward model
trainer.train(pipe, Aesthetics())
```

To use a trained model for inference, load it with the pipeline it originated from:

```python
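from diffusers import StableDiffusionPipeline

# Load the fine-tuned weights from the training output directory
pipe = StableDiffusionPipeline.from_pretrained("out/ddpo_exp")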
prompt = "A mad panda scientist"
image = pipe(prompt).images[0]
image.save("test.jpeg")
```

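# Accelerated Training

To train on multiple GPUs, first configure Accelerate for your hardware, then launch your training module through it: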

```bash
accelerate config
accelerate launch -m [your module]
```
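
If you start runs from Python (for example, from a job-submission script), the launch command can be assembled programmatically. The following is a minimal sketch, not part of DRLX: `build_launch_command` is a hypothetical helper, and `train_ddpo` is a placeholder for your own training module. It relies only on the `-m` and `--num_processes` options of `accelerate launch`.

```python
import shlex
from typing import List, Optional


def build_launch_command(module: str, num_processes: Optional[int] = None) -> List[str]:
    """Return the argument list for `accelerate launch -m <module>`.

    Hypothetical helper for illustration; it is not part of DRLX or Accelerate.
    """
    cmd = ["accelerate", "launch"]
    if num_processes is not None:
        # Overrides the process count saved by `accelerate config`
        cmd += ["--num_processes", str(num_processes)]
    # -m runs the target as a Python module, like `python -m`
    cmd += ["-m", module]
    return cmd


# "train_ddpo" is a placeholder module name (an assumption, not from DRLX)
command = build_launch_command("train_ddpo", num_processes=2)
print(shlex.join(command))
```

The returned list can be passed to `subprocess.run(command, check=True)` on a machine where Accelerate is installed and configured.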
