Metadata-Version: 2.1
Name: dreamfinetune
Version: 1.5
Home-page: https://github.com/skillfi/fine-tuning
Author: Alex
License: Apache 2.0 License
Classifier: Development Status :: 5 - Production/Stable
Classifier: Intended Audience :: Developers
Classifier: Intended Audience :: Education
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: Apache Software License
Classifier: Operating System :: OS Independent
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.8
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Requires-Python: >=3.8
Description-Content-Type: text/x-rst
License-File: LICENSE
Requires-Dist: pillow>=9.4.0
Requires-Dist: diffusers~=0.30.0
Requires-Dist: transformers<4.45.0,>=4.42.4
Requires-Dist: tqdm~=4.66.4
Requires-Dist: datasets~=2.20.0
Requires-Dist: bitsandbytes
Requires-Dist: ftfy
Requires-Dist: gradio
Requires-Dist: tensorboard
Provides-Extra: torch
Requires-Dist: torch>=2.3.1; extra == "torch"
Requires-Dist: torchvision>=0.18.0; extra == "torch"
Requires-Dist: torchaudio>=2.3.1; extra == "torch"

StableDiffusionInpaintingFineTune
=================================

This project provides a toolkit for fine-tuning the Stable Diffusion model for inpainting (restoring the region of an image selected by a mask) using PyTorch and the Hugging Face Diffusers library.

Requirements
------------

Before starting, install the following libraries:

- ``torch``
- ``diffusers``
- ``transformers``
- ``accelerate``
- ``huggingface_hub``
- ``PIL`` (installed as the ``pillow`` package)
- ``numpy``
- ``tqdm``

PyTorch built for CUDA 12.1 can be installed with:

.. code-block:: bash

   pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
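
To confirm that the libraries listed above are available before training, each one can be looked up by its import name. The following is a minimal sketch that uses only the standard library; the helper ``find_missing`` is hypothetical and is not part of this project.

.. code-block:: python

   import importlib.util

   # Import names of the libraries listed above (Pillow is imported as "PIL").
   REQUIRED = [
       "torch", "diffusers", "transformers", "accelerate",
       "huggingface_hub", "PIL", "numpy", "tqdm",
   ]


   def find_missing(modules):
       """Return the names in ``modules`` that cannot be located for import."""
       return [name for name in modules if importlib.util.find_spec(name) is None]


   if __name__ == "__main__":
       missing = find_missing(REQUIRED)
       print("Missing:", ", ".join(missing) if missing else "none")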

Description
-----------

StableDiffusionInpaintingFineTune
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

This class fine-tunes the Stable Diffusion model for the inpainting task. It trains the UNet and, optionally, the text encoder; the constructor arguments below control the training process.

Constructor
^^^^^^^^^^^

.. code-block:: python

   __init__(self, pretrained_model_name_or_path, resolution, center_crop, ...)

- **pretrained_model_name_or_path**: The path or name of the pre-trained model.
- **resolution**: The resolution of the images.
- **center_crop**: Whether to apply center cropping during data preparation.
- **train_text_encoder**: Whether to train the text encoder.
- **dataset**: The dataset object.
- **learning_rate**: The initial learning rate.
- **max_training_steps**: The maximum number of training steps.
- **save_steps**: The number of steps between saving checkpoints.
- **train_batch_size**: The batch size.
- **gradient_accumulation_steps**: The number of steps over which gradients are accumulated before each optimizer update.
- **mixed_precision**: The mixed-precision mode ("fp16", "bf16", or None).
- **gradient_checkpointing**: Whether to use gradient checkpointing.
- **use_8bit_adam**: Whether to use the 8-bit Adam optimizer.
- **seed**: The random seed for reproducibility.
- **output_dir**: The directory for saving results.
- **push_to_hub**: Whether to upload the results to the Hugging Face Hub.
- **repo_id**: The repository ID on Hugging Face Hub.
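
Several of these arguments interact. As a rough illustration, the sketch below derives the effective batch size and the steps at which checkpoints would be written. It assumes a single device, that a "step" means one optimizer update, and that a checkpoint is saved every ``save_steps`` steps; the ``TrainingSchedule`` class is hypothetical and not part of this project.

.. code-block:: python

   from dataclasses import dataclass


   @dataclass
   class TrainingSchedule:
       """Hypothetical helper mirroring four of the constructor arguments."""

       train_batch_size: int
       gradient_accumulation_steps: int
       max_training_steps: int
       save_steps: int

       @property
       def effective_batch_size(self):
           # Samples contributing to one optimizer update (single device assumed).
           return self.train_batch_size * self.gradient_accumulation_steps

       @property
       def checkpoint_steps(self):
           # Steps at which a checkpoint would be written under the stated assumption.
           return list(range(self.save_steps, self.max_training_steps + 1, self.save_steps))


   schedule = TrainingSchedule(
       train_batch_size=4,
       gradient_accumulation_steps=2,
       max_training_steps=1000,
       save_steps=250,
   )
   print(schedule.effective_batch_size)
   print(schedule.checkpoint_steps)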

Methods
^^^^^^^

- **prepare_mask_and_masked_image(image, mask)**: Prepares the mask and masked image.
- **random_mask(im_shape, ratio=1, mask_full_image=False)**: Generates a random mask.
- **load_args_for_training()**: Loads the model components needed for training.
- **collate_fn(examples)**: Assembles a list of examples into a batch for the model.
- **__call__(self, *args, **kwargs)**: The main method for running the training process.
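
The idea behind random masking can be sketched without any dependencies. The example below is not the project's implementation: it uses plain nested lists instead of images or tensors, the function names ``random_rectangle_mask`` and ``apply_mask`` are hypothetical, and it assumes the convention that a mask value of 1 marks the region to be inpainted.

.. code-block:: python

   import random


   def random_rectangle_mask(im_shape, ratio=1.0, mask_full_image=False):
       """Return a height x width grid of 0/1 values with one rectangle set to 1."""
       width, height = im_shape
       # Largest allowed rectangle, clamped to the image size and to at least 1 pixel.
       max_w = max(1, min(width, int(width * ratio)))
       max_h = max(1, min(height, int(height * ratio)))
       if mask_full_image:
           box_w, box_h = max_w, max_h
       else:
           box_w, box_h = random.randint(1, max_w), random.randint(1, max_h)
       # Choose the top-left corner so the rectangle stays inside the image.
       left = random.randint(0, width - box_w)
       top = random.randint(0, height - box_h)
       return [
           [1 if left <= x < left + box_w and top <= y < top + box_h else 0
            for x in range(width)]
           for y in range(height)
       ]


   def apply_mask(pixels, mask):
       """Zero out the masked pixels, leaving only the visible context."""
       return [
           [0 if m else p for p, m in zip(pixel_row, mask_row)]
           for pixel_row, mask_row in zip(pixels, mask)
       ]


   mask = random_rectangle_mask((8, 6), ratio=0.5)
   masked = apply_mask([[255] * 8 for _ in range(6)], mask)

During training, the model would receive the masked image together with the mask and learn to fill in the hidden region.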

Usage
-----

To start training, create an instance of the ``StableDiffusionInpaintingFineTune`` class with the necessary arguments and then call the instance, which invokes its ``__call__`` method.

.. code-block:: python

   model = StableDiffusionInpaintingFineTune(
       pretrained_model_name_or_path="path_to_model",
       resolution=512,
       center_crop=True,
       # ... remaining constructor arguments
   )

   model()

License
-------

The project is distributed under the Apache License 2.0.
