Metadata-Version: 2.1
Name: keras-aug
Version: 1.0.0
Summary: A library that includes Keras 3 preprocessing and augmentation layers
Author-email: Hong-Yu Chiu <james77777778@gmail.com>
Maintainer-email: Hong-Yu Chiu <james77777778@gmail.com>
License: Apache License 2.0
Project-URL: Homepage, https://github.com/james77777778/keras-aug
Project-URL: Documentation, https://github.com/james77777778/keras-aug
Project-URL: Repository, https://github.com/james77777778/keras-aug.git
Project-URL: Issues, https://github.com/james77777778/keras-aug/issues
Keywords: deep-learning,preprocessing,augmentation,keras,jax,tensorflow,torch
Classifier: Programming Language :: Python
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Programming Language :: Python :: 3 :: Only
Classifier: Operating System :: Unix
Classifier: Operating System :: MacOS
Classifier: Intended Audience :: Science/Research
Classifier: Topic :: Scientific/Engineering
Classifier: Topic :: Software Development
Requires-Python: >=3.9
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: keras
Provides-Extra: tests
Requires-Dist: isort; extra == "tests"
Requires-Dist: ruff; extra == "tests"
Requires-Dist: black; extra == "tests"
Requires-Dist: pytest; extra == "tests"
Requires-Dist: pytest-cov; extra == "tests"
Requires-Dist: coverage; extra == "tests"
Requires-Dist: pre-commit; extra == "tests"
Requires-Dist: namex; extra == "tests"

# KerasAug

<!-- markdownlint-disable MD033 -->

![Keras](https://img.shields.io/badge/keras-v3.4.1+-success.svg)
[![GitHub Workflow Status](https://img.shields.io/github/actions/workflow/status/james77777778/keras-aug/actions.yml?label=tests)](https://github.com/james77777778/keras-aug/actions/workflows/actions.yml?query=branch%3Amain++)
[![codecov](https://codecov.io/gh/james77777778/keras-aug/branch/main/graph/badge.svg?token=81ELI3VH7H)](https://codecov.io/gh/james77777778/keras-aug)
[![PyPI](https://img.shields.io/pypi/v/keras-aug)](https://pypi.org/project/keras-aug/)
![PyPI - Downloads](https://img.shields.io/pypi/dm/keras-aug)
[![Open in HF Spaces](https://huggingface.co/datasets/huggingface/badges/resolve/main/open-in-hf-spaces-sm-dark.svg)](https://huggingface.co/spaces/james77777778/KerasAug)

## Description

KerasAug is a library that includes Keras 3 preprocessing and augmentation layers, providing support for various data types such as images, labels, bounding boxes, segmentation masks, and more.

<div align="center">
<img width="45%" src="https://github.com/user-attachments/assets/bf9488c4-5c6b-4c87-8fa8-30170a67c92c" alt="object_detection.gif"> <img width="45%" src="https://github.com/user-attachments/assets/556db949-9461-438a-b1cf-3621ec63416e"  alt="semantic_segmentation.gif">
</div>

> [!NOTE]
> See `docs/*.py` for how these GIFs were generated. They show a YOLOV8-like augmentation pipeline applied to bounding boxes and segmentation masks.
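Augmenting detection and segmentation data means every geometric transform must be applied to the targets as well as to the pixels. The plain-Python sketch below illustrates that bookkeeping for a horizontal flip and a resize. It is not KerasAug's implementation: it assumes boxes in absolute `(x1, y1, x2, y2)` pixel coordinates and masks stored as nested lists of class IDs, and all helper names are hypothetical.

```python
# Hypothetical helpers showing how targets follow geometric transforms.
# Boxes are assumed to be (x1, y1, x2, y2) in absolute pixel coordinates;
# masks are assumed to be row-major nested lists of integer class IDs.


def hflip_boxes(boxes, image_width):
    """Mirror boxes horizontally: new x1 = W - old x2, new x2 = W - old x1."""
    return [(image_width - x2, y1, image_width - x1, y2) for x1, y1, x2, y2 in boxes]


def hflip_mask(mask):
    """Mirror a segmentation mask horizontally by reversing each row."""
    return [list(reversed(row)) for row in mask]


def resize_boxes(boxes, src_size, dst_size):
    """Scale boxes when an image is resized from src_size to dst_size (height, width)."""
    scale_y = dst_size[0] / src_size[0]
    scale_x = dst_size[1] / src_size[1]
    return [
        (x1 * scale_x, y1 * scale_y, x2 * scale_x, y2 * scale_y)
        for x1, y1, x2, y2 in boxes
    ]


def resize_mask_nearest(mask, dst_size):
    """Resize a mask with nearest-neighbor sampling so class IDs are never blended."""
    src_h, src_w = len(mask), len(mask[0])
    dst_h, dst_w = dst_size
    return [
        [mask[(y * src_h) // dst_h][(x * src_w) // dst_w] for x in range(dst_w)]
        for y in range(dst_h)
    ]


# Example: a 4x4 image with one box covering the left half, marked as class 1.
boxes = [(0, 0, 2, 4)]
mask = [[1, 1, 0, 0] for _ in range(4)]

print(hflip_boxes(boxes, image_width=4))    # [(2, 0, 4, 4)]
print(hflip_mask(mask)[0])                  # [0, 0, 1, 1]
print(resize_boxes(boxes, (4, 4), (8, 8)))  # [(0.0, 0.0, 4.0, 8.0)]
print(resize_mask_nearest(mask, (2, 2)))    # [[1, 0], [1, 0]]
```

Nearest-neighbor sampling is used for the mask because interpolating between class IDs (for example, averaging class 1 and class 3 into class 2) would produce labels that never existed.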

KerasAug aims to provide fast, robust and user-friendly preprocessing and augmentation layers, facilitating seamless integration with Keras 3 and `tf.data.Dataset`.

The APIs largely follow `torchvision`, and the correctness of the layers has been verified through unit tests.

You can also try the demo app on Hugging Face Spaces: [![Open in HF Spaces](https://huggingface.co/datasets/huggingface/badges/resolve/main/open-in-hf-spaces-sm-dark.svg)](https://huggingface.co/spaces/james77777778/KerasAug)

## Installation

```bash
pip install keras keras-aug -U
```

> [!IMPORTANT]
> Make sure you have installed a supported Keras backend (JAX, TensorFlow, or PyTorch).
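Keras 3 reads its backend from the `KERAS_BACKEND` environment variable (or `~/.keras/keras.json`), falls back to TensorFlow when neither is set, and fixes the choice once `keras` is imported. The sketch below is one way to pick whichever supported backend is installed. The preference order and the helper names (`installed_backends`, `select_backend`) are assumptions for illustration, not part of KerasAug.

```python
import importlib.util
import os

# Keras 3 backend names mapped to the package each one needs.
# The default preference order used below is an arbitrary choice for this sketch.
BACKEND_PACKAGES = {"jax": "jax", "tensorflow": "tensorflow", "torch": "torch"}


def installed_backends():
    """Return the backend names whose packages can be found without importing them."""
    return [
        name
        for name, package in BACKEND_PACKAGES.items()
        if importlib.util.find_spec(package) is not None
    ]


def select_backend(preferred=("jax", "tensorflow", "torch"), installed=None):
    """Set KERAS_BACKEND to the first preferred backend that is installed."""
    if installed is None:
        installed = installed_backends()
    for name in preferred:
        if name in installed:
            # Must happen before the first `import keras` in this process.
            os.environ["KERAS_BACKEND"] = name
            return name
    raise RuntimeError(
        "No supported Keras backend found; install one of: "
        + ", ".join(BACKEND_PACKAGES)
    )
```

Call `select_backend()` at the very top of a training script, before `import keras` or `from keras_aug import layers`; setting the variable afterwards has no effect on an already-imported Keras.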

## Quickstart

### Rock, Paper and Scissors Image Classification

```python
import keras
import tensorflow as tf
import tensorflow_datasets as tfds

from keras_aug import layers as ka_layers

BATCH_SIZE = 64
NUM_CLASSES = 3
INPUT_SIZE = (128, 128)

# Build a `tf.data.Dataset` preprocessing pipeline (works with all Keras backends)
train_dataset, validation_dataset = tfds.load(
    "rock_paper_scissors", as_supervised=True, split=["train", "test"]
)
train_dataset = (
    train_dataset.batch(BATCH_SIZE)
    .map(
        lambda images, labels: {
            "images": tf.cast(images, "float32") / 255.0,
            "labels": tf.one_hot(labels, NUM_CLASSES),
        }
    )
    .map(ka_layers.vision.Resize(INPUT_SIZE))
    .shuffle(128)
    .map(ka_layers.vision.RandAugment())
    .map(ka_layers.vision.CutMix(num_classes=NUM_CLASSES))
    .map(lambda data: (data["images"], data["labels"]))
    .prefetch(tf.data.AUTOTUNE)
)
validation_dataset = (
    validation_dataset.batch(BATCH_SIZE)
    .map(
        lambda images, labels: {
            "images": tf.cast(images, "float32") / 255.0,
            "labels": tf.one_hot(labels, NUM_CLASSES),
        }
    )
    .map(ka_layers.vision.Resize(INPUT_SIZE))
    .map(lambda data: (data["images"], data["labels"]))
    .prefetch(tf.data.AUTOTUNE)
)

# Create a CNN model
model = keras.models.Sequential(
    [
        keras.Input((*INPUT_SIZE, 3)),
        keras.layers.Conv2D(32, (3, 3), activation="relu"),
        keras.layers.MaxPooling2D(2, 2),
        keras.layers.Conv2D(64, (3, 3), activation="relu"),
        keras.layers.MaxPooling2D(2, 2),
        keras.layers.Conv2D(128, (3, 3), activation="relu"),
        keras.layers.MaxPooling2D(2, 2),
        keras.layers.Conv2D(256, (3, 3), activation="relu"),
        keras.layers.MaxPooling2D(2, 2),
        keras.layers.Flatten(),
        keras.layers.Dense(512, activation="relu"),
        keras.layers.Dense(NUM_CLASSES, activation="softmax"),
    ]
)
model.summary()
model.compile(
    loss="categorical_crossentropy",
    optimizer=keras.optimizers.AdamW(),
    metrics=["accuracy"],
)

# Train your model
model.fit(train_dataset, validation_data=validation_dataset, epochs=8)
```

The above example runs with all backends (JAX, TensorFlow, Torch).
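The `CutMix` step in the training pipeline pastes a rectangle from another image into the current one and mixes the labels in proportion to the pasted area. The plain-Python sketch below shows that bookkeeping for a single pair of images following the standard CutMix recipe; it is not KerasAug's code. It assumes images are nested lists of pixel values (height by width) and labels are one-hot lists, and `sample_cutmix_box` and `cutmix_pair` are hypothetical names. KerasAug's `CutMix` layer may sample or clip the box differently.

```python
import math
import random


def sample_cutmix_box(height, width, lam, rng=random):
    """Sample a box covering roughly (1 - lam) of the image, clipped to its bounds."""
    cut_ratio = math.sqrt(1.0 - lam)
    cut_h, cut_w = int(height * cut_ratio), int(width * cut_ratio)
    cy, cx = rng.randrange(height), rng.randrange(width)  # random box center
    y1, y2 = max(cy - cut_h // 2, 0), min(cy + cut_h // 2, height)
    x1, x2 = max(cx - cut_w // 2, 0), min(cx + cut_w // 2, width)
    return y1, x1, y2, x2


def cutmix_pair(image_a, label_a, image_b, label_b, alpha=1.0, rng=random):
    """Paste a random patch of image_b into image_a and mix the one-hot labels."""
    height, width = len(image_a), len(image_a[0])
    lam = rng.betavariate(alpha, alpha)  # mixing ratio ~ Beta(alpha, alpha)
    y1, x1, y2, x2 = sample_cutmix_box(height, width, lam, rng)

    mixed = [row[:] for row in image_a]  # copy so image_a is left untouched
    for y in range(y1, y2):
        mixed[y][x1:x2] = image_b[y][x1:x2]

    # Recompute the weight from the clipped box so the labels match the pixels.
    lam = 1.0 - (y2 - y1) * (x2 - x1) / (height * width)
    label = [lam * a + (1.0 - lam) * b for a, b in zip(label_a, label_b)]
    return mixed, label
```

In a batched pipeline, `image_b` is typically another sample from the same batch (for example, the batch shifted by one position), which is why `CutMix` is applied after `.batch(...)` in the example above. Because the labels become soft, the model is trained with `categorical_crossentropy` on one-hot targets rather than sparse integer labels.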

### More Examples

- [YOLOV8 object detection pipeline](guides/voc_yolov8_aug.py)
- [YOLOV8 semantic segmentation pipeline](guides/oxford_yolov8_aug.py)

## Gradio App

To deploy the demo app to Hugging Face Spaces with the Gradio CLI, run:

```bash
gradio deploy
```

## Citing KerasAug

```bibtex
@misc{chiu2023kerasaug,
  title={KerasAug},
  author={Chiu, Hongyu},
  year={2023},
  howpublished={\url{https://github.com/james77777778/keras-aug}},
}
```
