Metadata-Version: 2.1
Name: balanced-loss
Version: 0.1.1
Summary: Easy to use class-balanced cross-entropy and focal loss implementation for Pytorch.
Home-page: https://github.com/fcakyon/balanced-loss
Author: fcakyon
License: MIT
Keywords: machine-learning,deep-learning,ml,pytorch,vision,loss,image-classification,video-classification
Classifier: Development Status :: 5 - Production/Stable
Classifier: Operating System :: OS Independent
Classifier: Intended Audience :: Developers
Classifier: Intended Audience :: Science/Research
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.8
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Topic :: Software Development :: Libraries
Classifier: Topic :: Software Development :: Libraries :: Python Modules
Classifier: Topic :: Education
Classifier: Topic :: Scientific/Engineering
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Requires-Python: >=3.7
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: torch
Requires-Dist: numpy
Provides-Extra: tests
Requires-Dist: black==21.7b0; extra == "tests"
Requires-Dist: flake8==3.9.2; extra == "tests"
Requires-Dist: isort==5.9.2; extra == "tests"
Requires-Dist: click==8.0.4; extra == "tests"
Requires-Dist: importlib-metadata<4.3,>=1.1.0; python_version < "3.8" and extra == "tests"
Provides-Extra: dev
Requires-Dist: black==21.7b0; extra == "dev"
Requires-Dist: flake8==3.9.2; extra == "dev"
Requires-Dist: isort==5.9.2; extra == "dev"
Requires-Dist: click==8.0.4; extra == "dev"
Requires-Dist: importlib-metadata<4.3,>=1.1.0; python_version < "3.8" and extra == "dev"

<p align="center">
<img src="https://user-images.githubusercontent.com/34196005/180311379-1003da44-cdf9-46e8-af83-e65fbc3710cd.png" width="350">
</p>

<div align="center">
    <a href="https://badge.fury.io/py/balanced-loss"><img src="https://badge.fury.io/py/balanced-loss.svg" alt="pypi version"></a>
    <a href="https://pepy.tech/project/balanced-loss"><img src="https://pepy.tech/badge/balanced-loss" alt="total downloads"></a>
    <a href="https://twitter.com/fcakyon"><img src="https://img.shields.io/badge/twitter-fcakyon_-blue?logo=twitter&style=flat" alt="fcakyon twitter"></a>
</div>

<p align="center">
    Easy-to-use class-balanced cross-entropy and focal loss implementation for PyTorch.
</p>

## Theory

When the labels in a training dataset are imbalanced, one remedy is to reweight the loss per class so that rare classes are not drowned out by frequent ones.

- First, the effective number of samples is calculated for each class as:

![Formula for the effective number of samples per class](https://user-images.githubusercontent.com/34196005/180266195-aa2e8696-cdeb-48ed-a85f-7ffb353942a4.png)

- Then the class-balanced loss function is defined as:

![Formula for the class-balanced loss](https://user-images.githubusercontent.com/34196005/180266198-e27d8cba-f5e1-49ca-9f82-d8656333e3c4.png)
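As a rough illustration of the two formulas, the paper listed under References defines the effective number of samples for a class with `n` samples as `(1 - beta**n) / (1 - beta)` and scales that class's loss by its inverse. The sketch below computes the resulting per-class weights in plain Python. The helper names `effective_number` and `class_balanced_weights` are hypothetical, and rescaling the weights so they sum to the number of classes is an assumption made for readability, not necessarily what this package does internally.

```python
def effective_number(n, beta=0.999):
    """Effective number of samples for a class with n samples (hypothetical helper)."""
    return (1.0 - beta ** n) / (1.0 - beta)


def class_balanced_weights(samples_per_class, beta=0.999):
    """Per-class weights proportional to 1 / effective number (hypothetical helper).

    Assumption: the weights are rescaled to sum to the number of classes.
    """
    raw = [1.0 / effective_number(n, beta) for n in samples_per_class]
    scale = len(raw) / sum(raw)
    return [w * scale for w in raw]


# Same counts as in the usage examples: 30, 100, 25 samples for labels 0, 1, 2.
weights = class_balanced_weights([30, 100, 25], beta=0.999)

# With beta = 0 every class has an effective number of 1, so nothing is reweighted.
uniform = class_balanced_weights([30, 100, 25], beta=0.0)
```

With these counts the rarest class (label 2) receives the largest weight and the most frequent class (label 1) the smallest. As `beta` approaches 1 the weighting approaches inverse class frequency.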

## Installation

```bash
pip install balanced-loss
```

## Usage

- Standard losses:

```python
import torch
from balanced_loss import Loss

# outputs and labels
logits = torch.tensor([[0.78, 0.1, 0.05]]) # batch of 1 sample, 3 classes
labels = torch.tensor([0]) # one label per sample

# focal loss
focal_loss = Loss(loss_type="focal_loss")
loss = focal_loss(logits, labels)
```

```python
# cross-entropy loss
ce_loss = Loss(loss_type="cross_entropy")
loss = ce_loss(logits, labels)
```

```python
# binary cross-entropy loss
bce_loss = Loss(loss_type="binary_cross_entropy")
loss = bce_loss(logits, labels)
```

- Class-balanced losses:

```python
import torch
from balanced_loss import Loss

# outputs and labels
logits = torch.tensor([[0.78, 0.1, 0.05]]) # batch of 1 sample, 3 classes
labels = torch.tensor([0]) # one label per sample

# number of samples per class in the training dataset
samples_per_class = [30, 100, 25] # 30, 100, 25 samples for labels 0, 1 and 2, respectively

# class-balanced focal loss
focal_loss = Loss(
    loss_type="focal_loss",
    samples_per_class=samples_per_class,
    class_balanced=True
)
loss = focal_loss(logits, labels)
```

```python
# class-balanced cross-entropy loss
ce_loss = Loss(
    loss_type="cross_entropy",
    samples_per_class=samples_per_class,
    class_balanced=True
)
loss = ce_loss(logits, labels)
```

```python
# class-balanced binary cross-entropy loss
bce_loss = Loss(
    loss_type="binary_cross_entropy",
    samples_per_class=samples_per_class,
    class_balanced=True
)
loss = bce_loss(logits, labels)
```
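The `samples_per_class` list is indexed by label, so it usually has to be derived from the training labels. A minimal sketch of one way to do that with the standard library follows. The helper `count_samples_per_class` and the toy `train_labels` list are hypothetical, and labels are assumed to be integers in `range(num_classes)`.

```python
from collections import Counter


def count_samples_per_class(labels, num_classes):
    """Count how often each integer label occurs (hypothetical helper).

    Returns a list where index c holds the number of samples with label c.
    """
    counts = Counter(labels)
    return [counts.get(c, 0) for c in range(num_classes)]


# Toy training labels, for illustration only.
train_labels = [0, 1, 1, 2, 1, 0]
samples_per_class = count_samples_per_class(train_labels, num_classes=3)
print(samples_per_class)  # [2, 3, 1]
```

A class that never appears gets a count of 0, which makes its effective number 0 as well, so such classes probably need to be handled (or removed) before the counts are used for class balancing.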

- Customize parameters:

```python
import torch
from balanced_loss import Loss

# outputs and labels
logits = torch.tensor([[0.78, 0.1, 0.05]]) # batch of 1 sample, 3 classes
labels = torch.tensor([0]) # one label per sample

# number of samples per class in the training dataset
samples_per_class = [30, 100, 25] # 30, 100, 25 samples for labels 0, 1 and 2, respectively

# class-balanced focal loss
focal_loss = Loss(
    loss_type="focal_loss",
    beta=0.999, # class-balanced loss beta
    fl_gamma=2, # focal loss gamma
    samples_per_class=samples_per_class,
    class_balanced=True
)
loss = focal_loss(logits, labels)
```
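To give a feel for what the focal loss gamma controls, here is a small plain-Python sketch of the textbook softmax form of focal loss, `-(1 - p_t)**gamma * log(p_t)`, where `p_t` is the predicted probability of the true class. This is an assumption-laden illustration: the helpers `softmax` and `focal_loss_single` are hypothetical, and the package may use a different formulation internally (for example a sigmoid-based one), so the numbers are not expected to match `Loss(loss_type="focal_loss")`.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats (hypothetical helper)."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def focal_loss_single(logits, label, gamma=2.0):
    """Textbook softmax focal loss for one sample (hypothetical helper)."""
    p_t = softmax(logits)[label]
    # The (1 - p_t) ** gamma factor down-weights samples that are already well classified.
    return -((1.0 - p_t) ** gamma) * math.log(p_t)


logits = [0.78, 0.1, 0.05]
ce = focal_loss_single(logits, label=0, gamma=0.0)  # gamma = 0 reduces to cross-entropy
fl = focal_loss_single(logits, label=0, gamma=2.0)  # gamma = 2 shrinks the loss
```

Larger values of gamma shrink the loss of confidently correct samples more aggressively, which shifts the training signal toward hard examples.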

## Improvements

What is the difference between this repo and vandit15's?

- This repo is a PyPI-installable package
- This repo implements loss functions as `torch.nn.Module`
- In addition to class-balanced losses, this repo also supports the standard versions of cross-entropy, focal loss, etc. through the same API
- All typos and errors in vandit15's source are fixed
- Continuously tested on PyTorch 1.13.1 and 2.5.1

## References

https://arxiv.org/abs/1901.05555

https://github.com/richardaecn/class-balanced-loss

https://github.com/vandit15/Class-balanced-loss-pytorch
