Metadata-Version: 2.3
Name: loadax
Version: 0.1.0
Summary: Dataloading for Jax
Author-email: Nick Wall <46641379+walln@users.noreply.github.com>
License: MIT License
        
        Copyright (c) 2023 Alex Johansson
        
        Permission is hereby granted, free of charge, to any person obtaining a copy
        of this software and associated documentation files (the "Software"), to deal
        in the Software without restriction, including without limitation the rights
        to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
        copies of the Software, and to permit persons to whom the Software is
        furnished to do so, subject to the following conditions:
        
        The above copyright notice and this permission notice shall be included in all
        copies or substantial portions of the Software.
        
        THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
        IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
        FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
        AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
        LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
        OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
        SOFTWARE.
License-File: LICENSE
Requires-Python: >=3.8
Requires-Dist: jax>=0.4.15
Description-Content-Type: text/markdown

# Loadax

Loadax is a dataloading library designed for the JAX ecosystem. It provides utilities for feeding data into your training loop without having to worry about batching, shuffling, and other preprocessing steps. Loadax also supports offloading data loading to background workers, prefetching into a cache to improve performance, and JAX-native distributed data loading.

> [!IMPORTANT]
> Loadax is currently in early development, and the rest of this README is a working draft.

## Installation

```bash
pip install loadax
```

## Usage

### Data Loading

Loadax provides a simple interface for loading data into your training loop. Here is an example of loading data from a list of items:

```python
from loadax import DataLoader, InMemoryDataset, Batcher

dataset = InMemoryDataset([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
batcher = Batcher(lambda x: x)
loader = DataLoader(batcher).batch_size(2).build(dataset)

for batch in loader:
    print(batch)

# Output:
# [1, 2]
# [3, 4]
# [5, 6]
# [7, 8]
# [9, 10]
```

A dataloader is a definition of how to load data from a dataset. It is itself stateless, which lets you define multiple dataloaders for the same dataset, and even multiple iterators for the same dataloader.

```python
dataloader = DataLoader(batcher).batch_size(2).build(dataset)

fast_iterator = iter(dataloader)
slow_iterator = iter(dataloader)

val = next(fast_iterator)
print(val)
# Output: [1, 2]

# The second iterator is independent, so it also starts from the first batch
val = next(slow_iterator)
print(val)
# Output: [1, 2]
```

In the above examples we create an object called a batcher. A batcher is an interface that defines how to collate your data into batches. This is useful when you want to alter the way your data is batched, such as stacking the items into a single array.
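
As a rough illustration of what a collate function might do, the sketch below transposes a batch of `(features, label)` pairs into one list of features and one list of labels. The name `collate_pairs` and the sample items are hypothetical; the only assumption carried over from the examples above is that the function given to `Batcher` receives a list of dataset items.

```python
from typing import List, Sequence, Tuple


def collate_pairs(
    items: Sequence[Tuple[List[float], int]],
) -> Tuple[List[List[float]], List[int]]:
    """Turn a list of (features, label) pairs into (all_features, all_labels)."""
    features = [list(feature) for feature, _ in items]
    labels = [label for _, label in items]
    return features, labels


# Two hypothetical dataset items, as a batch of size 2 might deliver them.
items = [([0.1, 0.2], 0), ([0.3, 0.4], 1)]
features, labels = collate_pairs(items)
print(features)  # [[0.1, 0.2], [0.3, 0.4]]
print(labels)    # [0, 1]
```

Under that assumption it would be passed as `Batcher(collate_pairs)`. To stack into a single array instead, the body could call something like `jnp.stack` on the collected features.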

### Data Prefetching

When training models, it is essential to ensure that you are not blocking the training loop, and especially your accelerator(s), with IO-bound tasks. Loadax provides a simple interface for prefetching data into a cache using background worker(s).

```python
from loadax import DataLoader, InMemoryDataset, Batcher

dataset = InMemoryDataset([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
batcher = Batcher(lambda x: x)
loader = DataLoader(batcher).batch_size(2).prefetch(3).build(dataset)

for batch in loader:
    print(batch)

# Output:
# [1, 2]
# [3, 4]
# [5, 6]
# [7, 8]
# [9, 10]
```

In the above example we create a dataloader with a prefetch factor of 3. This means that the loader will prefetch 3 batches ahead of the current index. The future batches are kept in a cache and, depending on your configuration, can be eagerly loaded into device memory or kept in host memory.
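
To make the idea of a prefetch factor concrete, here is a minimal standard-library sketch of a bounded buffer filled by a background thread. It is not Loadax's implementation; the `prefetch` function and its `factor` parameter are hypothetical names, and error handling and early termination are left out for brevity.

```python
import queue
import threading
from typing import Iterable, Iterator, TypeVar

T = TypeVar("T")
_DONE = object()  # sentinel marking the end of the source


def prefetch(source: Iterable[T], factor: int) -> Iterator[T]:
    """Yield items from `source`, keeping up to `factor` of them buffered ahead."""
    buffer: "queue.Queue[object]" = queue.Queue(maxsize=factor)

    def producer() -> None:
        for item in source:
            buffer.put(item)  # blocks while the buffer is full
        buffer.put(_DONE)

    threading.Thread(target=producer, daemon=True).start()

    while True:
        item = buffer.get()  # blocks until the producer has something ready
        if item is _DONE:
            return
        yield item


batches = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]
prefetched = list(prefetch(batches, factor=3))
print(prefetched)  # [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]
```

The bounded queue is what caps how far ahead the producer may run: once `factor` items are waiting, it pauses until the consumer takes one.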

### Using Multiple Workers

In the same way that the dataloader can be used to prefetch data, it can also offload the dataloading to multiple background workers. Let's take a look at an example of why you may want to do this.

In the following example we have a dataset where loading an individual item is slow due to some preprocessing. Ignore the details of the `MappedDataset` for now, as we will get to that later; just know that it lazily transforms the data from the source dataset.

```python
import time

from loadax import DataLoader, RangeDataset, MappedDataset, Batcher

def slow_fn(x):
    time.sleep(0.1)
    return x * 2

dataset = MappedDataset(RangeDataset(0, 10), slow_fn)
batcher = Batcher(lambda x: x)
loader = DataLoader(batcher).batch_size(2).workers(2).build(dataset)

for batch in loader:
    print(batch)

# Output:
# [0, 2]
# [4, 6]
# [8, 10]
# [12, 14]
# [16, 18]
```

In the above example we create a dataloader with 2 workers. This means that the loader will create 2 background workers to load the data. The items are loaded in parallel, allowing the workers to do the slow processing, and the results are then batched and ready for consumption.

An important note is that the implementation of the background workers currently leverages the `concurrent.futures` library, because multiprocessing does not work well with JAX. This means each node uses a single Python process, and depending on your Python version and how IO-bound your dataset loading is, you may occasionally see GIL contention.
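
The effect of thread-based workers on IO-bound loading can be sketched with `concurrent.futures.ThreadPoolExecutor` directly. This is only an illustration of the mechanism under the assumption that workers are threads; `load_items` is a hypothetical helper, not part of Loadax.

```python
import time
from concurrent.futures import ThreadPoolExecutor


def slow_fn(x: int) -> int:
    time.sleep(0.05)  # stands in for IO-bound work; sleeping releases the GIL
    return x * 2


def load_items(indices, num_workers):
    """Apply slow_fn to every index on a thread pool, preserving input order."""
    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        return list(pool.map(slow_fn, indices))


start = time.perf_counter()
items = load_items(range(10), num_workers=2)
elapsed = time.perf_counter() - start

# Group the loaded items into batches of 2
batches = [items[i:i + 2] for i in range(0, len(items), 2)]
print(batches)  # [[0, 2], [4, 6], [8, 10], [12, 14], [16, 18]]
print(f"{elapsed:.2f}s with 2 workers")
```

Because the simulated work only sleeps, two threads overlap almost perfectly. CPU-bound transformations written in pure Python would contend for the GIL and gain much less.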

### Distributed Data Loading

Loadax also supports distributed data loading. This means that you can easily shard your dataset across multiple nodes/JAX processes and load data in parallel. Loadax will automatically determine which elements to load on each shard within the network, ensuring that the data is evenly distributed and that each node only gets the data it needs.

With the inter-node distribution handled for you, it is now trivial to build advanced distributed training loops with paradigms such as model and data parallelism.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from loadax import DataLoader, InMemoryDataset, Batcher

# Create a 2D mesh across all the jax devices: every device sits on the "data"
# axis, and the "model" axis has size 1
devices = np.array(jax.devices()).reshape(-1, 1)
mesh = Mesh(devices, ("data", "model"))

# Create a partition spec for the mesh: batches are 1D, so split them along "data"
partition_spec = PartitionSpec("data")

dataset_size = 100
batch_size = 10  # should be divisible by the number of devices on the "data" axis

# Create dataloader for a jax process
dataset = InMemoryDataset(list(range(dataset_size)))
batcher = Batcher(lambda x: jnp.stack(x))

dataloader = (
    DataLoader(batcher)
    .batch_size(batch_size)
    .workers(2)
    .prefetch(2)
    .shard(mesh, partition_spec)
    .build(dataset)
)

# Define a simple model function. You can imagine this is some Flax model or
# something similar; it may even be sharded itself along some other axis, such
# as "model" for model parallelism
def simple_model(x, params):
    return x * params

jitted_model = jax.jit(simple_model)

# Replicate the parameters on every device in the mesh
params = jnp.array([2.0])
sharded_params = jax.device_put(params, NamedSharding(mesh, PartitionSpec()))

def compute_loss(batch, predictions):
    # Your loss calculation logic
    return jnp.mean(...)

total_loss = 0.0

for batch in dataloader:
    # Distribute the batch across the local devices
    local_batch = jax.device_put(jnp.array(batch), NamedSharding(mesh, partition_spec))

    # Apply the model and compute the loss
    predictions = jitted_model(local_batch, sharded_params)
    loss = compute_loss(local_batch, predictions)

    # jax.lax.pmean is only valid inside a mapped context such as pmap or
    # shard_map, so here the loss is accumulated directly
    total_loss += loss
```

The sharding primitives that Loadax provides are powerful because they declare the way data is distributed up front. This enables Loadax to be deterministic as it decides which elements to load on each shard, and even which elements to load into each specific batch. This guaranteed determinism lets you focus on other things rather than on ensuring that your dataloading is correct and reproducible.
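
One way to picture this determinism is as a pure function from `(dataset_size, num_shards, shard_id)` to a list of indices. The sketch below uses a contiguous split, which is an assumption chosen for illustration and not necessarily the scheme Loadax uses; `shard_indices` and `batches_for_shard` are hypothetical helpers.

```python
from typing import List


def shard_indices(dataset_size: int, num_shards: int, shard_id: int) -> List[int]:
    """Return the dataset indices owned by `shard_id` under a contiguous split."""
    if not 0 <= shard_id < num_shards:
        raise ValueError(f"shard_id must be in [0, {num_shards}), got {shard_id}")
    per_shard, remainder = divmod(dataset_size, num_shards)
    # The first `remainder` shards take one extra element each.
    start = shard_id * per_shard + min(shard_id, remainder)
    stop = start + per_shard + (1 if shard_id < remainder else 0)
    return list(range(start, stop))


def batches_for_shard(indices: List[int], batch_size: int) -> List[List[int]]:
    """Split a shard's indices into fixed-size batches, in order."""
    return [indices[i:i + batch_size] for i in range(0, len(indices), batch_size)]


# 100 elements over 4 hypothetical processes, batch size 10.
shard_1 = shard_indices(100, num_shards=4, shard_id=1)
print(shard_1[0], shard_1[-1])            # 25 49
print(batches_for_shard(shard_1, 10)[0])  # [25, 26, 27, 28, 29, 30, 31, 32, 33, 34]
```

Since the assignment depends only on its arguments, every process can compute its own share without communicating, and a rerun produces the same batches.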

### Type Hinting

Another benefit of Loadax is that the underlying shape of your data is passed through all the way into your training loop. This means you can use type hints to ensure that your data is the correct shape.

```python
from loadax import DataLoader, RangeDataset, Batcher

# RangeDataset has a DatasetItem type of int. This is a generic argument that can be
# supplied to any dataset type. We can look more into this when we get to datasets.
dataset = RangeDataset(0, 10)

# this function is inferred to return an int
def my_fn(x: list[int]) -> int:
    return sum(x)

batcher = Batcher(my_fn)
loader = DataLoader(batcher).batch_size(2).build(dataset)

for batch in loader:
    print(batch)

# Output (each batch is the sum of two consecutive items from 0..9):
# 1
# 5
# 9
# 13
# 17
```

Because you define the Batcher (or use a predefined one for common operations), the type of the batch can be inferred all the way from the dataset definition.

### Datasets

Loadax provides a simple interface for defining your dataset. As long as you can perform indexed access on your data, you can use Loadax to load your data. See the [Dataset Protocol](https://github.com/walln/loadax/blob/main/src/loadax/dataset/protocol.py) for more details.
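
As a rough sketch of what "indexed access" could look like, the class below exposes a length and item lookup by integer index. The method names follow Python's sequence conventions, which is an assumption; the linked protocol file is the authority on the methods Loadax actually expects, and `SquaresDataset` is a hypothetical example.

```python
class SquaresDataset:
    """Hypothetical dataset whose item at index i is i squared."""

    def __init__(self, size: int) -> None:
        self._size = size

    def __len__(self) -> int:
        return self._size

    def __getitem__(self, index: int) -> int:
        if not 0 <= index < self._size:
            raise IndexError(index)
        return index * index


squares = SquaresDataset(5)
print(len(squares))                               # 5
print([squares[i] for i in range(len(squares))])  # [0, 1, 4, 9, 16]
```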

Additionally, Loadax provides a few common datasets that can be used out of the box. These include:

- InMemoryDataset
- RangeDataset

```python
dataset = InMemoryDataset([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
dataset = RangeDataset(0, 10)
```

Datasets can also be transformed using a variety of transformations. Transformations are lazily applied to the dataset, meaning that they are only applied when the data is actually accessed. Because your dataloader is likely prefetching and using background workers, this should not block your training loop. This also means that you can use JAX to jit-compile your transformation function.

```python
import time

from loadax import MappedDataset, RangeDataset, ShuffledDataset

def slow_fn(x):
    time.sleep(0.1)
    return x * 2

base_dataset = ShuffledDataset(RangeDataset(0, 10))
dataset = MappedDataset(base_dataset, slow_fn)
```

When iterating through `dataset`, `slow_fn` will be applied lazily to the underlying dataset, which is itself lazily shuffling the range dataset. This composable pattern allows you to build complex dataloading pipelines.
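
The lazy, composable behaviour described above can be illustrated with a small wrapper that defers its function until an item is requested. `LazyMapped` is a hypothetical stand-in written for this sketch, not Loadax's `MappedDataset`.

```python
from typing import Callable, Generic, List, Sequence, TypeVar

T = TypeVar("T")
U = TypeVar("U")


class LazyMapped(Generic[T, U]):
    """Hypothetical wrapper that applies `fn` only when an item is accessed."""

    def __init__(self, source: Sequence[T], fn: Callable[[T], U]) -> None:
        self._source = source
        self._fn = fn

    def __len__(self) -> int:
        return len(self._source)

    def __getitem__(self, index: int) -> U:
        return self._fn(self._source[index])


calls: List[int] = []


def double(x: int) -> int:
    calls.append(x)  # record each invocation to make the laziness visible
    return x * 2


# Wrappers compose: the outer one adds 1 to whatever the inner one produces.
doubled = LazyMapped(range(10), double)
pipeline = LazyMapped(doubled, lambda x: x + 1)

print(calls)        # [] (nothing has been computed yet)
print(pipeline[3])  # 7
print(calls)        # [3] (only the accessed item was transformed)
```

Each layer only holds a reference to the layer beneath it, so building a pipeline costs nothing until items are read.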

#### Dataset Integrations

Loadax has a few common dataset sources on the roadmap, including:

- PolarsDataset
- SQLiteDataset
- HuggingFaceDataset

Feel free to open an issue if you have a use case that you would like to see included.

### Batchers

Batchers are used to define how to collate your data into batches.
