Metadata-Version: 2.1
Name: mlx-image
Version: 0.1.1
Summary: Apple mlx image models library
Author: Riccardo Musmeci
Author-email: riccardomusmeci92@gmail.com
Requires-Python: >=3.10,<4.0
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Provides-Extra: dev
Provides-Extra: test
Requires-Dist: Pillow
Requires-Dist: albumentations
Requires-Dist: black[jupyter] (==23.3.0) ; extra == "dev"
Requires-Dist: docstr-coverage (==2.2.0) ; extra == "dev"
Requires-Dist: huggingface_hub
Requires-Dist: matplotlib
Requires-Dist: mlx
Requires-Dist: mypy (==1.2.0) ; extra == "dev"
Requires-Dist: numpy (==1.26.2)
Requires-Dist: opencv-python
Requires-Dist: pandas
Requires-Dist: pre-commit (==3.3.1) ; extra == "dev"
Requires-Dist: pytest (==7.3.1) ; extra == "test"
Requires-Dist: pytest-cov (==4.0.0) ; extra == "test"
Requires-Dist: pytest-sugar (==0.9.7) ; extra == "test"
Requires-Dist: pytest-xdist (==3.2.1) ; extra == "test"
Requires-Dist: ruff (==0.0.264) ; extra == "dev"
Requires-Dist: safetensors
Requires-Dist: torch (==2.0)
Requires-Dist: tqdm
Requires-Dist: twine (>=4.0.0,<5) ; extra == "dev"
Requires-Dist: types-pyyaml (>=6.0.12.12,<7.0.0.0)
Description-Content-Type: text/markdown

# **mlxim**
Image models based on the [Apple MLX framework](https://github.com/ml-explore/mlx) for Apple Silicon machines.

## **Why? 💡**

The Apple MLX framework is a great tool for running machine learning models on Apple Silicon machines.

This repository converts image models from timm/torchvision to the Apple MLX framework. The weights are only converted from `.pth` to `.npz`; the models **are not trained again**.

I don't have enough compute power (and time) to train all the models from scratch (**someone buy me a maxed-out Mac, please**).

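One detail matters when converting weights: PyTorch stores `Conv2d` kernels as `(out, in, height, width)`, while `mlx.nn.Conv2d` expects `(out, height, width, in)`. Below is a minimal NumPy sketch of that layout change. It assumes every 4-D array is a `Conv2d` kernel. `to_mlx_layout` and the toy state dict are hypothetical, and the conversion code used in this repository may differ.

```python
import io

import numpy as np


def to_mlx_layout(state):
    """Return a copy of `state` with 4-D conv kernels moved to channels-last.

    Assumption: every 4-D array is a Conv2d kernel stored PyTorch-style as
    (out_channels, in_channels, height, width).
    """
    converted = {}
    for name, array in state.items():
        if array.ndim == 4:
            # (O, I, H, W) -> (O, H, W, I), the layout mlx.nn.Conv2d expects
            array = np.transpose(array, (0, 2, 3, 1))
        converted[name] = np.ascontiguousarray(array)
    return converted


# Stand-in for a PyTorch state dict whose tensors were already turned into
# NumPy arrays (made-up shapes, not real pretrained weights).
torch_style = {
    "conv1.weight": np.zeros((64, 3, 7, 7), dtype=np.float32),
    "bn1.weight": np.ones((64,), dtype=np.float32),
}

mlx_style = to_mlx_layout(torch_style)

# Write to an in-memory buffer; pass a file path instead to get a real .npz.
buffer = io.BytesIO()
np.savez(buffer, **mlx_style)
```

Non-convolutional parameters such as the batch-norm weight pass through unchanged; only the kernel axes are reordered.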
## **How to install**
```
pip install mlx-image
```

## **Models**
Model weights are available in the [`mlx-vision`](https://huggingface.co/mlx-vision) space on HuggingFace.

To create a model with weights:
```python
from mlxim.model import create_model

# loading weights from HuggingFace
model = create_model("resnet18")

# loading weights from local file
model = create_model("resnet18", weights="path/to/weights.npz")
```

To list all available models:
```python
from mlxim.model import list_models
list_models()
```
> [!WARNING]
> As of 2024-03-05, MLX does not support `nn.Conv2d` with `groups` or `dilation` greater than 1, which affects models that rely on them (e.g. `resnext`, `regnet`, `efficientnet`).

## **ImageNet-1K Results**
See [results-imagenet-1k.csv](results/results-imagenet-1k.csv) for every converted model and its performance on ImageNet-1K under different settings.

## **Train**
Training is similar to PyTorch, thanks to some tools `mlx-im` provides. Here's an example of how to train a model with `mlx-im`:

```python
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from mlxim.model import create_model
from mlxim.data import LabelFolderDataset, DataLoader

train_dataset = LabelFolderDataset(
    root_dir="path/to/train",
    class_map={0: "class_0", 1: "class_1", 2: ["class_2", "class_3"]}
)
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4
)
model = create_model("resnet18")  # pretrained weights loaded from HF
optimizer = optim.Adam(learning_rate=1e-3)

def train_step(model, inputs, targets):
    # forward pass, then mean cross-entropy over the batch
    logits = model(inputs)
    loss = mx.mean(nn.losses.cross_entropy(logits, targets))
    return loss

# build the loss-and-gradient function once, outside the training loop
train_step_fn = nn.value_and_grad(model, train_step)

model.train()
for epoch in range(10):
    for batch in train_loader:
        x, target = batch
        # the wrapped function takes the same arguments as train_step
        loss, grads = train_step_fn(model, x, target)
        optimizer.update(model, grads)
        # force evaluation of the lazily computed parameters and optimizer state
        mx.eval(model.state, optimizer.state)
```
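In the example above, `class_map` appears to map each integer label to one folder name or to a list of folder names, so that several folders can share a label. That reading is an assumption based on the example only. The sketch below shows how such a map could be inverted into a folder-to-label lookup; `folder_to_label` is a hypothetical helper and not part of `mlx-im`.

```python
def folder_to_label(class_map):
    """Invert {label: folder or [folders]} into {folder: label}."""
    lookup = {}
    for label, folders in class_map.items():
        # Accept both a single folder name and a list of folder names
        if isinstance(folders, str):
            folders = [folders]
        for folder in folders:
            if folder in lookup:
                raise ValueError(f"folder {folder!r} is assigned to two labels")
            lookup[folder] = label
    return lookup


class_map = {0: "class_0", 1: "class_1", 2: ["class_2", "class_3"]}
lookup = folder_to_label(class_map)
```

Here `class_2` and `class_3` both end up with label `2`, which is how two folders would be merged into one class.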

## **Validation**

The `validation.py` script runs every time a `.pth` model is converted to MLX, to check that the converted model performs similarly to the original one on ImageNet-1K.

I use the configuration file `config/validation.yaml` to set the parameters for the validation script.

You can download the ImageNet-1K validation set from mlx-vision space on HuggingFace at this [link](https://huggingface.co/datasets/mlx-vision/imagenet-1k).

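To illustrate the kind of metric used to compare a converted model with its original, here is a minimal pure-Python sketch of top-k accuracy. It is not the code in `validation.py`; `topk_accuracy` and the toy scores are made up for illustration.

```python
def topk_accuracy(scores, labels, k=1):
    """Fraction of samples whose true label is among the k highest scores."""
    hits = 0
    for row, label in zip(scores, labels):
        # Indices of the k largest scores in this row
        top = sorted(range(len(row)), key=lambda i: row[i], reverse=True)[:k]
        hits += label in top
    return hits / len(labels)


# Toy scores for 3 images over 4 classes (made-up numbers)
scores = [
    [0.1, 0.7, 0.1, 0.1],
    [0.5, 0.2, 0.3, 0.0],
    [0.2, 0.1, 0.3, 0.4],
]
labels = [1, 2, 3]

top1 = topk_accuracy(scores, labels, k=1)
top2 = topk_accuracy(scores, labels, k=2)
```

Running the same function on the outputs of the original and the converted model over the same images gives two numbers that should be close if the conversion preserved the weights.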
## **Similarity to PyTorch and other familiar tools**
`mlx-im` tries to be as close as possible to PyTorch:
- `DataLoader` -> you can define your own `collate_fn` and also use `num_workers` to speed up the data loading
- `Dataset` -> `mlx-im` already supports `LabelFolderDataset` (the good old PyTorch `ImageFolder`) and `FolderDataset` (a generic folder with images in it)
- `ModelCheckpoint` -> keeps track of the best model and saves it to disk (similar to PyTorch Lightning); it also suggests early stopping

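The best-model tracking and early-stopping idea behind `ModelCheckpoint` can be sketched in a few lines of plain Python. This is not the `mlx-im` implementation: `BestMetricTracker`, its `patience` and `mode` parameters, and the accuracy values are all hypothetical.

```python
class BestMetricTracker:
    """Track the best value of a metric and flag when to stop early."""

    def __init__(self, patience=3, mode="max"):
        if mode not in ("min", "max"):
            raise ValueError("mode must be 'min' or 'max'")
        self.patience = patience
        self.mode = mode
        self.best = None
        self.stale_epochs = 0

    def step(self, value):
        """Record one epoch's metric; return True if it is a new best."""
        improved = (
            self.best is None
            or (self.mode == "max" and value > self.best)
            or (self.mode == "min" and value < self.best)
        )
        if improved:
            self.best = value
            self.stale_epochs = 0  # a real checkpoint would save weights here
        else:
            self.stale_epochs += 1
        return improved

    @property
    def should_stop(self):
        # Suggest stopping after `patience` epochs without improvement
        return self.stale_epochs >= self.patience


tracker = BestMetricTracker(patience=2, mode="max")
history = []
for accuracy in [0.61, 0.70, 0.69, 0.68, 0.75]:
    history.append(tracker.step(accuracy))
    if tracker.should_stop:
        break
```

With `patience=2`, the loop stops after the second epoch without improvement, so the last value in the list is never seen.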
## **Contributing**

This is a work in progress, so any help is appreciated.

I am working on it in my spare time, so I can't guarantee frequent updates.

If you love coding and want to contribute, follow the instructions in [CONTRIBUTING.md](CONTRIBUTING.md).

## **To-Do**

- [ ] add `register_model` (similar to timm)
- [ ] inference script (similar to train/validation)
- [ ] SwinTransformer
- [ ] DenseNet
- [ ] MobileNet


## Contact

If you have any questions, please email `riccardomusmeci92@gmail.com`.

